mirror of
https://github.com/3b1b/manim.git
synced 2025-11-14 06:27:46 +00:00
Clean up and fix VMobject.get_shader_wrapper_list
This commit is contained in:
parent
25ac5f3507
commit
3165a28cd0
1 changed files with 19 additions and 28 deletions
|
|
@ -148,11 +148,6 @@ class VMobject(Mobject):
|
|||
raise Exception("All submobjects must be of type VMobject")
|
||||
super().add(*vmobjects)
|
||||
|
||||
def copy(self, deep: bool = False) -> VMobject:
|
||||
result = super().copy(deep)
|
||||
result.shader_wrapper_list = [sw.copy() for sw in self.shader_wrapper_list]
|
||||
return result
|
||||
|
||||
# Colors
|
||||
def init_colors(self):
|
||||
self.set_fill(
|
||||
|
|
@ -809,6 +804,11 @@ class VMobject(Mobject):
|
|||
# Alignment
|
||||
def align_points(self, vmobject: VMobject):
|
||||
if self.get_num_points() == len(vmobject.get_points()):
|
||||
# If both have fill, and they have the same shape, just
|
||||
# give them the same triangulation so that it's not recalculated
|
||||
# needlessly throughout an animation
|
||||
if self.has_fill() and vmobject.has_fill() and self.has_same_shape_as(vmobject):
|
||||
vmobject.triangulation = self.triangulation
|
||||
return
|
||||
|
||||
for mob in self, vmobject:
|
||||
|
|
@ -1077,14 +1077,6 @@ class VMobject(Mobject):
|
|||
render_primitive=self.render_primitive,
|
||||
)
|
||||
|
||||
self.shader_wrapper_list = [
|
||||
self.stroke_shader_wrapper.copy(), # Use for back stroke
|
||||
self.fill_shader_wrapper.copy(),
|
||||
self.stroke_shader_wrapper.copy(),
|
||||
]
|
||||
for sw in self.shader_wrapper_list:
|
||||
sw.uniforms = self.uniforms
|
||||
|
||||
def refresh_shader_wrapper_id(self):
|
||||
for wrapper in [self.fill_shader_wrapper, self.stroke_shader_wrapper]:
|
||||
wrapper.refresh_id()
|
||||
|
|
@ -1107,30 +1099,29 @@ class VMobject(Mobject):
|
|||
# Build up data lists
|
||||
fill_shader_wrappers = []
|
||||
stroke_shader_wrappers = []
|
||||
back_stroke_shader_wrappers = []
|
||||
for submob in self.family_members_with_points():
|
||||
if submob.has_fill():
|
||||
fill_shader_wrappers.append(submob.get_fill_shader_wrapper())
|
||||
if submob.has_stroke():
|
||||
ssw = submob.get_stroke_shader_wrapper()
|
||||
stroke_shader_wrappers.append(submob.get_stroke_shader_wrapper())
|
||||
if submob.draw_stroke_behind_fill:
|
||||
back_stroke_shader_wrappers.append(ssw)
|
||||
else:
|
||||
stroke_shader_wrappers.append(ssw)
|
||||
self.draw_stroke_behind_fill = True
|
||||
|
||||
# Combine data lists
|
||||
sw_lists = [
|
||||
back_stroke_shader_wrappers,
|
||||
fill_shader_wrappers,
|
||||
stroke_shader_wrappers,
|
||||
]
|
||||
for sw, sw_list in zip(self.shader_wrapper_list, sw_lists):
|
||||
self_sws = [self.fill_shader_wrapper, self.stroke_shader_wrapper]
|
||||
sw_lists = [fill_shader_wrappers, stroke_shader_wrappers]
|
||||
for sw, sw_list in zip(self_sws, sw_lists):
|
||||
if not sw_list:
|
||||
sw.vert_data = resize_array(sw.vert_data, 0)
|
||||
continue
|
||||
if sw is sw_list[0]:
|
||||
sw.combine_with(*sw_list[1:])
|
||||
else:
|
||||
sw.read_in(*sw_list)
|
||||
sw.depth_test = any(sw.depth_test for sw in sw_list)
|
||||
sw.uniforms.update(sw_list[0].uniforms)
|
||||
return list(filter(lambda sw: len(sw.vert_data) > 0, self.shader_wrapper_list))
|
||||
if self.draw_stroke_behind_fill:
|
||||
self_sws.reverse()
|
||||
return [sw for sw in self_sws if len(sw.vert_data) > 0]
|
||||
|
||||
def get_stroke_shader_data(self) -> np.ndarray:
|
||||
points = self.get_points()
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue