Cleanup VMobject shader wrapper methods

Deleting those which are no longer needed
This commit is contained in:
Grant Sanderson 2023-01-16 11:50:31 -08:00
parent 74b42a6eb5
commit bdcfbc39ec

View file

@ -64,18 +64,8 @@ class VMobject(Mobject):
('orientation', np.float32, (1,)),
('vert_index', np.float32, (1,)),
])
fill_dtype: np.dtype = np.dtype([
('point', np.float32, (3,)),
('fill_rgba', np.float32, (4,)),
('orientation', np.float32, (1,)),
('vert_index', np.float32, (1,)),
])
stroke_dtype: np.dtype = np.dtype([
('point', np.float32, (3,)),
('stroke_rgba', np.float32, (4,)),
('stroke_width', np.float32, (1,)),
('joint_angle', np.float32, (1,)),
])
fill_data_names = ['point', 'fill_rgba', 'orientation', 'vert_index']
stroke_data_names = ['point', 'stroke_rgba', 'stroke_width', 'joint_angle']
fill_render_primitive: int = moderngl.TRIANGLES
stroke_render_primitive: int = moderngl.TRIANGLE_STRIP
@ -1175,17 +1165,25 @@ class VMobject(Mobject):
# For shaders
def init_shader_data(self):
self.fill_data = np.zeros(0, dtype=self.fill_dtype)
self.stroke_data = np.zeros(0, dtype=self.stroke_dtype)
dtype = self.shader_dtype
fill_dtype, stroke_dtype = (
np.dtype([
(name, dtype[name].base, dtype[name].shape)
for name in names
])
for names in [self.fill_data_names, self.stroke_data_names]
)
fill_data = np.zeros(0, dtype=fill_dtype)
stroke_data = np.zeros(0, dtype=stroke_dtype)
self.fill_shader_wrapper = ShaderWrapper(
vert_data=self.fill_data,
vert_data=fill_data,
vert_indices=np.zeros(0, dtype='i4'),
uniforms=self.uniforms,
shader_folder=self.fill_shader_folder,
render_primitive=self.fill_render_primitive,
)
self.stroke_shader_wrapper = ShaderWrapper(
vert_data=self.stroke_data,
vert_data=stroke_data,
uniforms=self.uniforms,
shader_folder=self.stroke_shader_folder,
render_primitive=self.stroke_render_primitive,
@ -1193,89 +1191,51 @@ class VMobject(Mobject):
self.back_stroke_shader_wrapper = self.stroke_shader_wrapper.copy()
def refresh_shader_wrapper_id(self):
for wrapper in [self.fill_shader_wrapper, self.stroke_shader_wrapper]:
for wrapper in self.get_shader_wrapper_list():
wrapper.refresh_id()
return self
def get_fill_shader_wrapper(self) -> ShaderWrapper:
self.fill_shader_wrapper.vert_indices = self.get_triangulation()
self.fill_shader_wrapper.vert_data = self.get_fill_shader_data()
self.fill_shader_wrapper.uniforms = self.get_uniforms()
self.fill_shader_wrapper.depth_test = self.depth_test
return self.fill_shader_wrapper
def get_stroke_shader_wrapper(self) -> ShaderWrapper:
self.stroke_shader_wrapper.vert_data = self.get_stroke_shader_data()
self.stroke_shader_wrapper.uniforms = self.get_uniforms()
self.stroke_shader_wrapper.depth_test = self.depth_test
return self.stroke_shader_wrapper
def get_shader_wrapper_list(self) -> list[ShaderWrapper]:
family = self.family_members_with_points()
fill_names = self.fill_data_names
stroke_names = self.stroke_data_names
# Build up data lists
fill_submobs = []
stroke_submobs = []
bstroke_submobs = []
for submob in self.family_members_with_points():
fill_datas = []
fill_indices = []
stroke_datas = []
back_stroke_data = []
for submob in family:
if submob.has_fill():
fill_submobs.append(submob)
fill_datas.append(submob.data[fill_names])
fill_indices.append(submob.get_triangulation())
if submob.has_stroke():
if submob.draw_stroke_behind_fill:
bstroke_submobs.append(submob)
lst = back_stroke_data
else:
stroke_submobs.append(submob)
fill_names = list(self.fill_data.dtype.names)
self.fill_shader_wrapper.read_in(
[sm.data[fill_names] for sm in fill_submobs],
[sm.get_fill_shader_vert_indices() for sm in fill_submobs],
)
self.stroke_shader_wrapper.read_in(
[sm.get_stroke_shader_data() for sm in stroke_submobs],
)
self.back_stroke_shader_wrapper.read_in(
[sm.get_stroke_shader_data() for sm in bstroke_submobs],
)
lst = stroke_datas
lst.append(submob.data[stroke_names])
# Set data array to be one longer than number of points,
# with a dummy vertex added at the end. This is to ensure
# it can be safely stacked onto other stroke data arrays.
lst.append(submob.data[stroke_names][-1:])
shader_wrappers = [
self.back_stroke_shader_wrapper,
self.fill_shader_wrapper,
self.stroke_shader_wrapper
self.back_stroke_shader_wrapper.read_in(back_stroke_data),
self.fill_shader_wrapper.read_in(fill_datas, fill_indices),
self.stroke_shader_wrapper.read_in(stroke_datas),
]
for sw in shader_wrappers:
# TODO, handle depth test and uniforms...
pass
# Assume uniforms of the first family member
sw.uniforms = family[0].get_uniforms()
sw.depth_test = family[0].depth_test
return [sw for sw in shader_wrappers if len(sw.vert_data) > 0]
def get_stroke_shader_data(self) -> np.ndarray:
# Set data array to be one longer than number of points,
# with a dummy vertex added at the end. This is to ensure
# it can be safely stacked onto other stroke data arrays.
n = len(self.data)
size = n + 1 if n > 0 else 0
self.stroke_data = resize_array(self.stroke_data, size)
if n == 0:
return self.stroke_data
self.get_joint_angles() # Recomputes, only if refresh is needed
for key in self.stroke_data.dtype.names:
self.stroke_data[key][:n] = self.data[key]
self.stroke_data[-1] = self.stroke_data[-2]
return self.stroke_data
def get_fill_shader_data(self) -> np.ndarray:
self.fill_data = resize_array(self.fill_data, len(self.data))
for key in self.fill_data.dtype.names:
self.fill_data[key][:] = self.data[key]
return self.fill_data
def refresh_shader_data(self):
self.get_fill_shader_data()
self.get_stroke_shader_data()
def get_fill_shader_vert_indices(self) -> np.ndarray:
return self.get_triangulation()
for submob in self.get_family():
submob.get_joint_angles()
self.get_shader_wrapper_list()
class VGroup(VMobject):