From ba9f61b50ba162ce5629f3bbca286763f2cd1f7d Mon Sep 17 00:00:00 2001 From: Grant Sanderson Date: Sun, 15 Jan 2023 20:27:19 -0800 Subject: [PATCH] Have ShaderWrapper read in data rather than other shader wrappers --- manimlib/mobject/types/vectorized_mobject.py | 50 +++++++++++--------- manimlib/shader_wrapper.py | 37 ++++++++++----- 2 files changed, 52 insertions(+), 35 deletions(-) diff --git a/manimlib/mobject/types/vectorized_mobject.py b/manimlib/mobject/types/vectorized_mobject.py index 6bfd8101..106d6284 100644 --- a/manimlib/mobject/types/vectorized_mobject.py +++ b/manimlib/mobject/types/vectorized_mobject.py @@ -1212,37 +1212,40 @@ class VMobject(Mobject): def get_shader_wrapper_list(self) -> list[ShaderWrapper]: # Build up data lists - fill_sws = [] - stroke_sws = [] - bstroke_sws = [] + fill_submobs = [] + stroke_submobs = [] + bstroke_submobs = [] for submob in self.family_members_with_points(): if submob.has_fill(): - fill_sws.append(submob.get_fill_shader_wrapper()) + fill_submobs.append(submob) if submob.has_stroke(): - lst = bstroke_sws if submob.draw_stroke_behind_fill else stroke_sws - lst.append(submob.get_stroke_shader_wrapper()) + if submob.draw_stroke_behind_fill: + bstroke_submobs.append(submob) + else: + stroke_submobs.append(submob) - self_sws = [ + fill_names = list(self.fill_data.dtype.names) + self.fill_shader_wrapper.read_in( + [sm.data[fill_names] for sm in fill_submobs], + [sm.get_fill_shader_vert_indices() for sm in fill_submobs], + ) + self.stroke_shader_wrapper.read_in( + [sm.get_stroke_shader_data() for sm in stroke_submobs], + ) + self.back_stroke_shader_wrapper.read_in( + [sm.get_stroke_shader_data() for sm in bstroke_submobs], + ) + + + shader_wrappers = [ self.back_stroke_shader_wrapper, self.fill_shader_wrapper, self.stroke_shader_wrapper ] - sw_lists = [ - bstroke_sws, - fill_sws, - stroke_sws - ] - for sw, sw_list in zip(self_sws, sw_lists): - if not sw_list: - sw.vert_data = resize_array(sw.vert_data, 0) - continue - if sw is sw_list[0]: - sw.combine_with(*sw_list[1:]) - else: - sw.read_in(*sw_list) - sw.depth_test = any(sw.depth_test for sw in sw_list) - sw.uniforms.update(sw_list[0].uniforms) - return [sw for sw in self_sws if len(sw.vert_data) > 0] + for sw in shader_wrappers: + # TODO, handle depth test and uniforms... + pass + return [sw for sw in shader_wrappers if len(sw.vert_data) > 0] def get_stroke_shader_data(self) -> np.ndarray: # Set data array to be one longer than number of points, @@ -1255,6 +1258,7 @@ class VMobject(Mobject): return self.stroke_data self.get_joint_angles() # Recomputes, only if refresh is needed + for key in self.stroke_data.dtype.names: self.stroke_data[key][:n] = self.data[key] self.stroke_data[-1] = self.stroke_data[-2] diff --git a/manimlib/shader_wrapper.py b/manimlib/shader_wrapper.py index f7fb95d3..ba476d2b 100644 --- a/manimlib/shader_wrapper.py +++ b/manimlib/shader_wrapper.py @@ -16,7 +16,7 @@ from manimlib.utils.iterables import resize_array from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Iterable + from typing import Iterable, List # Mobjects that should be rendered with @@ -136,25 +136,38 @@ class ShaderWrapper(object): def combine_with(self, *shader_wrappers: ShaderWrapper) -> ShaderWrapper: if len(shader_wrappers) > 0: - self.read_in(self.copy(), *shader_wrappers) + data_list = [self.vert_data, *(sw.vert_data for sw in shader_wrappers)] + if self.vert_indices is not None: + indices_list = [self.vert_indices, *(sw.vert_indices for sw in shader_wrappers)] + else: + indices_list = None + self.read_in(data_list, indices_list) return self - def read_in(self, *shader_wrappers: ShaderWrapper) -> ShaderWrapper: + def read_in( + self, + vert_data_list: List[np.ndarray], + vert_indices_list: List[np.ndarray] | None = None + ) -> ShaderWrapper: # Assume all are of the same type - total_len = sum(len(sw.vert_data) for sw in shader_wrappers) + total_len = sum(len(data) for data in vert_data_list) self.vert_data = resize_array(self.vert_data, total_len) - if self.vert_indices is not None: - total_verts = sum(len(sw.vert_indices) for sw in shader_wrappers) + if total_len == 0: + return self + + if vert_indices_list is not None and self.vert_indices is not None: + total_verts = sum(len(vi) for vi in vert_indices_list) self.vert_indices = resize_array(self.vert_indices, total_verts) n_points = 0 n_verts = 0 - for sw in shader_wrappers: - new_n_points = n_points + len(sw.vert_data) - self.vert_data[n_points:new_n_points] = sw.vert_data - if self.vert_indices is not None and sw.vert_indices is not None: - new_n_verts = n_verts + len(sw.vert_indices) - self.vert_indices[n_verts:new_n_verts] = sw.vert_indices + n_points + for k, data in enumerate(vert_data_list): + new_n_points = n_points + len(data) + self.vert_data[n_points:new_n_points] = data + if self.vert_indices is not None and vert_indices_list is not None: + vert_indices = vert_indices_list[k] + new_n_verts = n_verts + len(vert_indices) + self.vert_indices[n_verts:new_n_verts] = vert_indices + n_points n_verts = new_n_verts n_points = new_n_points return self