Allow Mobject.get_family(recurse) for common recursive methods

This commit is contained in:
Grant Sanderson 2021-01-11 17:03:12 -10:00
parent 4d9498322e
commit 6b451dcc22
4 changed files with 62 additions and 78 deletions

View file

@ -144,8 +144,8 @@ class Mobject(object):
parent.assemble_family()
return self
def get_family(self):
return self.family
def get_family(self, recurse=True):
return self.family if recurse else [self]
def family_members_with_points(self):
return [m for m in self.get_family() if m.has_points()]
@ -672,31 +672,30 @@ class Mobject(object):
return self
# Color functions
def set_rgba_array(self, color=None, opacity=None, name="rgbas", family=True):
def set_rgba_array(self, color=None, opacity=None, name="rgbas", recurse=True):
# TODO, account for if color or opacity are tuples
rgb = color_to_rgb(color) if color else None
mobs = self.get_family() if family else [self]
for mob in mobs:
for mob in self.get_family(recurse):
if rgb is not None:
mob.data[name][:, :3] = rgb
if opacity is not None:
mob.data[name][:, 3] = opacity
return self
def set_color(self, color, opacity=None, family=True):
self.set_rgba_array(color, opacity, family=False)
def set_color(self, color, opacity=None, recurse=True):
self.set_rgba_array(color, opacity, recurse=False)
# Recurse to submobjects differently from how set_rgba_array
# in case they implement set_color differently
if family:
if recurse:
for submob in self.submobjects:
submob.set_color(color, family=True)
submob.set_color(color, recurse=True)
return self
def set_opacity(self, opacity, family=True):
self.set_rgba_array(color=None, opacity=opacity, family=False)
if family:
def set_opacity(self, opacity, recurse=True):
self.set_rgba_array(color=None, opacity=opacity, recurse=False)
if recurse:
for submob in self.submobjects:
submob.set_opacity(opacity, family=True)
submob.set_opacity(opacity, recurse=True)
return self
def get_color(self):
@ -719,30 +718,26 @@ class Mobject(object):
new_colors = color_gradient(colors, len(mobs))
for mob, color in zip(mobs, new_colors):
mob.set_color(color, family=False)
mob.set_color(color, recurse=False)
return self
def fade(self, darkness=0.5, family=True):
self.set_opacity(1.0 - darkness, family=family)
def fade(self, darkness=0.5, recurse=True):
self.set_opacity(1.0 - darkness, recurse=recurse)
def get_gloss(self):
return self.gloss
def set_gloss(self, gloss, family=True):
self.gloss = gloss
if family:
for submob in self.submobjects:
submob.set_gloss(gloss, family)
def set_gloss(self, gloss, recurse=True):
for mob in self.get_family(recurse):
mob.gloss = gloss
return self
def get_shadow(self):
return self.shadow
def set_shadow(self, shadow, family=True):
self.shadow = shadow
if family:
for submob in self.submobjects:
submob.set_shadow(shadow, family)
def set_shadow(self, shadow, recurse=True):
for mob in self.get_family(recurse):
mob.shadow = shadow
return self
##
@ -1050,10 +1045,10 @@ class Mobject(object):
self.submobjects.sort(key=lambda m: point_to_num_func(m.get_center()))
return self
def shuffle(self, recursive=False):
if recursive:
def shuffle(self, recurse=False):
if recurse:
for submob in self.submobjects:
submob.shuffle(recursive=True)
submob.shuffle(recurse=True)
random.shuffle(self.submobjects)
return self

View file

@ -38,10 +38,9 @@ class ImageMobject(Mobject):
self.set_width(2 * size[0] / size[1], stretch=True)
self.set_height(self.height)
def set_opacity(self, opacity, family=True):
def set_opacity(self, opacity, recurse=True):
# TODO, account for opacity coming in as an array
mobs = self.get_family() if family else [self]
for mob in mobs:
for mob in self.get_family(recurse):
mob.data["opacity"][:, 0] = opacity
return self

View file

@ -221,9 +221,8 @@ class TexturedSurface(ParametricSurface):
def init_colors(self):
pass
def set_opacity(self, opacity, family=True):
mobs = self.get_family() if family else [self]
for mob in mobs:
def set_opacity(self, opacity, recurse=True):
for mob in self.get_family(recurse):
mob.data["opacity"][:] = opacity
return self

View file

@ -77,7 +77,7 @@ class VMobject(Mobject):
self.needs_new_triangulation = True
self.triangulation = np.zeros(0, dtype='i4')
super().__init__(**kwargs)
self.lock_unit_normal(family=False)
self.lock_unit_normal(recurse=False)
def get_group_class(self):
return VGroup
@ -112,20 +112,18 @@ class VMobject(Mobject):
self.set_flat_stroke(self.flat_stroke)
return self
def set_fill(self, color=None, opacity=None, family=True):
self.set_rgba_array(color, opacity, 'fill_rgba', family)
def set_fill(self, color=None, opacity=None, recurse=True):
self.set_rgba_array(color, opacity, 'fill_rgba', recurse)
def set_stroke(self, color=None, width=None, opacity=None, background=None, family=True):
self.set_rgba_array(color, opacity, 'stroke_rgba', family)
def set_stroke(self, color=None, width=None, opacity=None, background=None, recurse=True):
self.set_rgba_array(color, opacity, 'stroke_rgba', recurse)
if width is not None:
mobs = self.get_family() if family else [self]
for mob in mobs:
for mob in self.get_family(recurse):
mob.data['stroke_width'][:] = width
if background is not None:
mobs = self.get_family() if family else [self]
for mob in mobs:
for mob in self.get_family(recurse):
mob.draw_stroke_behind_fill = background
return self
@ -139,14 +137,14 @@ class VMobject(Mobject):
stroke_width=None,
gloss=None,
shadow=None,
family=True):
recurse=True):
if fill_rgba is not None:
self.data['fill_rgba'] = resize_with_interpolation(fill_rgba, len(fill_rgba))
else:
self.set_fill(
color=fill_color,
opacity=fill_opacity,
family=family
recurse=recurse
)
if stroke_rgba is not None:
@ -157,13 +155,13 @@ class VMobject(Mobject):
color=stroke_color,
width=stroke_width,
opacity=stroke_opacity,
family=family,
recurse=recurse,
)
if gloss is not None:
self.set_gloss(gloss, family=family)
self.set_gloss(gloss, recurse=recurse)
if shadow is not None:
self.set_shadow(shadow, family=family)
self.set_shadow(shadow, recurse=recurse)
return self
def get_style(self):
@ -175,9 +173,9 @@ class VMobject(Mobject):
"shadow": self.get_shadow(),
}
def match_style(self, vmobject, family=True):
self.set_style(**vmobject.get_style(), family=False)
if family:
def match_style(self, vmobject, recurse=True):
self.set_style(**vmobject.get_style(), recurse=False)
if recurse:
# Does its best to match up submobject lists, and
# match styles accordingly
submobs1, submobs2 = self.submobjects, vmobject.submobjects
@ -189,27 +187,27 @@ class VMobject(Mobject):
sm1.match_style(sm2)
return self
def set_color(self, color, family=True):
self.set_fill(color, family=family)
self.set_stroke(color, family=family)
def set_color(self, color, recurse=True):
self.set_fill(color, recurse=recurse)
self.set_stroke(color, recurse=recurse)
return self
def set_opacity(self, opacity, family=True):
self.set_fill(opacity=opacity, family=family)
self.set_stroke(opacity=opacity, family=family)
def set_opacity(self, opacity, recurse=True):
self.set_fill(opacity=opacity, recurse=recurse)
self.set_stroke(opacity=opacity, recurse=recurse)
return self
def fade(self, darkness=0.5, family=True):
def fade(self, darkness=0.5, recurse=True):
factor = 1.0 - darkness
self.set_fill(
opacity=factor * self.get_fill_opacity(),
family=False,
recurse=False,
)
self.set_stroke(
opacity=factor * self.get_stroke_opacity(),
family=False,
recurse=False,
)
super().fade(darkness, family)
super().fade(darkness, recurse)
return self
def get_fill_colors(self):
@ -274,9 +272,8 @@ class VMobject(Mobject):
return self.get_fill_opacity()
return self.get_stroke_opacity()
def set_flat_stroke(self, flat_stroke=True, family=True):
mobs = self.get_family() if family else [self]
for mob in mobs:
def set_flat_stroke(self, flat_stroke=True, recurse=True):
for mob in self.get_family(recurse):
mob.flat_stroke = flat_stroke
return self
@ -378,12 +375,8 @@ class VMobject(Mobject):
self.get_points()[0], self.get_points()[-1]
)
def subdivide_sharp_curves(self, angle_threshold=30 * DEGREES, family=True):
if family:
vmobs = self.family_members_with_points()
else:
vmobs = [self] if self.has_points() else []
def subdivide_sharp_curves(self, angle_threshold=30 * DEGREES, recurse=True):
vmobs = [vm for vm in self.get_family(recurse) if vm.has_points()]
for vmob in vmobs:
new_points = []
for tup in vmob.get_bezier_tuples():
@ -617,9 +610,8 @@ class VMobject(Mobject):
points[2] - points[1],
)
def lock_unit_normal(self, family=True):
mobs = self.get_family() if family else [self]
for mob in mobs:
def lock_unit_normal(self, recurse=True):
for mob in self.get_family(recurse):
mob.unit_normal_locked = False
mob.saved_unit_normal = mob.get_unit_normal()
mob.unit_normal_locked = True
@ -681,9 +673,8 @@ class VMobject(Mobject):
vmobject.set_points(np.vstack(new_subpaths2))
return self
def insert_n_curves(self, n, family=True):
mobs = self.get_family() if family else [self]
for mob in mobs:
def insert_n_curves(self, n, recurse=True):
for mob in self.get_family(recurse):
if mob.get_num_curves() > 0:
new_points = mob.insert_n_curves_to_point_list(n, mob.get_points())
# TODO, this should happen in insert_n_curves_to_point_list
@ -1005,4 +996,4 @@ class DashedVMobject(VMobject):
])
# Family is already taken care of by get_subcurve
# implementation
self.match_style(vmobject, family=False)
self.match_style(vmobject, recurse=False)