diff --git a/manimlib/mobject/types/vectorized_mobject.py b/manimlib/mobject/types/vectorized_mobject.py index 7dae8b4b..a6ebf03e 100644 --- a/manimlib/mobject/types/vectorized_mobject.py +++ b/manimlib/mobject/types/vectorized_mobject.py @@ -64,18 +64,8 @@ class VMobject(Mobject): ('orientation', np.float32, (1,)), ('vert_index', np.float32, (1,)), ]) - fill_dtype: np.dtype = np.dtype([ - ('point', np.float32, (3,)), - ('fill_rgba', np.float32, (4,)), - ('orientation', np.float32, (1,)), - ('vert_index', np.float32, (1,)), - ]) - stroke_dtype: np.dtype = np.dtype([ - ('point', np.float32, (3,)), - ('stroke_rgba', np.float32, (4,)), - ('stroke_width', np.float32, (1,)), - ('joint_angle', np.float32, (1,)), - ]) + fill_data_names = ['point', 'fill_rgba', 'orientation', 'vert_index'] + stroke_data_names = ['point', 'stroke_rgba', 'stroke_width', 'joint_angle'] fill_render_primitive: int = moderngl.TRIANGLES stroke_render_primitive: int = moderngl.TRIANGLE_STRIP @@ -1175,17 +1165,25 @@ class VMobject(Mobject): # For shaders def init_shader_data(self): - self.fill_data = np.zeros(0, dtype=self.fill_dtype) - self.stroke_data = np.zeros(0, dtype=self.stroke_dtype) + dtype = self.shader_dtype + fill_dtype, stroke_dtype = ( + np.dtype([ + (name, dtype[name].base, dtype[name].shape) + for name in names + ]) + for names in [self.fill_data_names, self.stroke_data_names] + ) + fill_data = np.zeros(0, dtype=fill_dtype) + stroke_data = np.zeros(0, dtype=stroke_dtype) self.fill_shader_wrapper = ShaderWrapper( - vert_data=self.fill_data, + vert_data=fill_data, vert_indices=np.zeros(0, dtype='i4'), uniforms=self.uniforms, shader_folder=self.fill_shader_folder, render_primitive=self.fill_render_primitive, ) self.stroke_shader_wrapper = ShaderWrapper( - vert_data=self.stroke_data, + vert_data=stroke_data, uniforms=self.uniforms, shader_folder=self.stroke_shader_folder, render_primitive=self.stroke_render_primitive, @@ -1193,89 +1191,51 @@ class VMobject(Mobject): self.back_stroke_shader_wrapper = self.stroke_shader_wrapper.copy() def refresh_shader_wrapper_id(self): - for wrapper in [self.fill_shader_wrapper, self.stroke_shader_wrapper]: + for wrapper in self.get_shader_wrapper_list(): wrapper.refresh_id() return self - def get_fill_shader_wrapper(self) -> ShaderWrapper: - self.fill_shader_wrapper.vert_indices = self.get_triangulation() - self.fill_shader_wrapper.vert_data = self.get_fill_shader_data() - self.fill_shader_wrapper.uniforms = self.get_uniforms() - self.fill_shader_wrapper.depth_test = self.depth_test - return self.fill_shader_wrapper - - def get_stroke_shader_wrapper(self) -> ShaderWrapper: - self.stroke_shader_wrapper.vert_data = self.get_stroke_shader_data() - self.stroke_shader_wrapper.uniforms = self.get_uniforms() - self.stroke_shader_wrapper.depth_test = self.depth_test - return self.stroke_shader_wrapper - def get_shader_wrapper_list(self) -> list[ShaderWrapper]: + family = self.family_members_with_points() + fill_names = self.fill_data_names + stroke_names = self.stroke_data_names + # Build up data lists - fill_submobs = [] - stroke_submobs = [] - bstroke_submobs = [] - for submob in self.family_members_with_points(): + fill_datas = [] + fill_indices = [] + stroke_datas = [] + back_stroke_data = [] + for submob in family: if submob.has_fill(): - fill_submobs.append(submob) + fill_datas.append(submob.data[fill_names]) + fill_indices.append(submob.get_triangulation()) if submob.has_stroke(): if submob.draw_stroke_behind_fill: - bstroke_submobs.append(submob) + lst = back_stroke_data else: - stroke_submobs.append(submob) - - fill_names = list(self.fill_data.dtype.names) - self.fill_shader_wrapper.read_in( - [sm.data[fill_names] for sm in fill_submobs], - [sm.get_fill_shader_vert_indices() for sm in fill_submobs], - ) - self.stroke_shader_wrapper.read_in( - [sm.get_stroke_shader_data() for sm in stroke_submobs], - ) - self.back_stroke_shader_wrapper.read_in( - [sm.get_stroke_shader_data() for sm in bstroke_submobs], - ) - + lst = stroke_datas + lst.append(submob.data[stroke_names]) + # Set data array to be one longer than number of points, + # with a dummy vertex added at the end. This is to ensure + # it can be safely stacked onto other stroke data arrays. + lst.append(submob.data[stroke_names][-1:]) shader_wrappers = [ - self.back_stroke_shader_wrapper, - self.fill_shader_wrapper, - self.stroke_shader_wrapper + self.back_stroke_shader_wrapper.read_in(back_stroke_data), + self.fill_shader_wrapper.read_in(fill_datas, fill_indices), + self.stroke_shader_wrapper.read_in(stroke_datas), ] + for sw in shader_wrappers: - # TODO, handle depth test and uniforms... - pass + # Assume uniforms of the first family member + sw.uniforms = family[0].get_uniforms() + sw.depth_test = family[0].depth_test return [sw for sw in shader_wrappers if len(sw.vert_data) > 0] - def get_stroke_shader_data(self) -> np.ndarray: - # Set data array to be one longer than number of points, - # with a dummy vertex added at the end. This is to ensure - # it can be safely stacked onto other stroke data arrays. - n = len(self.data) - size = n + 1 if n > 0 else 0 - self.stroke_data = resize_array(self.stroke_data, size) - if n == 0: - return self.stroke_data - - self.get_joint_angles() # Recomputes, only if refresh is needed - - for key in self.stroke_data.dtype.names: - self.stroke_data[key][:n] = self.data[key] - self.stroke_data[-1] = self.stroke_data[-2] - return self.stroke_data - - def get_fill_shader_data(self) -> np.ndarray: - self.fill_data = resize_array(self.fill_data, len(self.data)) - for key in self.fill_data.dtype.names: - self.fill_data[key][:] = self.data[key] - return self.fill_data - def refresh_shader_data(self): - self.get_fill_shader_data() - self.get_stroke_shader_data() - - def get_fill_shader_vert_indices(self) -> np.ndarray: - return self.get_triangulation() + for submob in self.get_family(): + submob.get_joint_angles() + self.get_shader_wrapper_list() class VGroup(VMobject):