From 7f940fbee4729d50f8a8656759105c88d551b34d Mon Sep 17 00:00:00 2001 From: Grant Sanderson Date: Thu, 2 Feb 2023 17:45:52 -0800 Subject: [PATCH] Change how ShaderWrapper uniforms are handled --- manimlib/mobject/mobject.py | 6 +-- manimlib/mobject/types/vectorized_mobject.py | 50 +++++++++++--------- manimlib/scene/scene.py | 5 +- manimlib/shader_wrapper.py | 28 ++++++----- 4 files changed, 49 insertions(+), 40 deletions(-) diff --git a/manimlib/mobject/mobject.py b/manimlib/mobject/mobject.py index 34052b16..cf35b3b3 100644 --- a/manimlib/mobject/mobject.py +++ b/manimlib/mobject/mobject.py @@ -1967,7 +1967,7 @@ class Mobject(object): self.shader_wrapper.vert_data = self.get_shader_data() self.shader_wrapper.vert_indices = self.get_shader_vert_indices() - self.shader_wrapper.update_program_uniforms(self.get_uniforms()) + self.shader_wrapper.bind_to_mobject_uniforms(self.get_uniforms()) self.shader_wrapper.depth_test = self.depth_test return self.shader_wrapper @@ -2004,9 +2004,7 @@ class Mobject(object): shader_wrapper.generate_vao() self._data_has_changed = False for shader_wrapper in self.shader_wrappers: - shader_wrapper.depth_test = self.depth_test - shader_wrapper.update_program_uniforms(self.get_uniforms()) - shader_wrapper.update_program_uniforms(camera_uniforms, universal=True) + shader_wrapper.update_program_uniforms(camera_uniforms) shader_wrapper.pre_render() shader_wrapper.render() diff --git a/manimlib/mobject/types/vectorized_mobject.py b/manimlib/mobject/types/vectorized_mobject.py index 461d0fe6..59bf2393 100644 --- a/manimlib/mobject/types/vectorized_mobject.py +++ b/manimlib/mobject/types/vectorized_mobject.py @@ -1284,14 +1284,14 @@ class VMobject(Mobject): self.fill_shader_wrapper = FillShaderWrapper( ctx=ctx, vert_data=fill_data, - uniforms=self.uniforms, + mobject_uniforms=self.uniforms, shader_folder=self.fill_shader_folder, render_primitive=self.fill_render_primitive, ) self.stroke_shader_wrapper = ShaderWrapper( ctx=ctx, vert_data=stroke_data, - uniforms=self.uniforms, + mobject_uniforms=self.uniforms, shader_folder=self.stroke_shader_folder, render_primitive=self.stroke_render_primitive, ) @@ -1309,11 +1309,6 @@ class VMobject(Mobject): wrapper.refresh_id() return self - def get_uniforms(self): - # TODO, account for submob uniforms separately? - self.uniforms.update(self.family_members_with_points()[0].uniforms) - return self.uniforms - def get_shader_wrapper_list(self, ctx: Context) -> list[ShaderWrapper]: if not self._shaders_initialized: self.init_shader_data(ctx) @@ -1325,32 +1320,25 @@ class VMobject(Mobject): fill_names = self.fill_data_names stroke_names = self.stroke_data_names - # Build up data lists + fill_family = (sm for sm in family if sm._has_fill) + stroke_family = (sm for sm in family if sm._has_stroke) + + # Build up fill data lists fill_datas = [] fill_indices = [] fill_border_datas = [] - stroke_datas = [] - back_stroke_datas = [] - for submob in family: - submob.get_joint_products() + for submob in fill_family: indices = submob.get_outer_vert_indices() - has_fill = submob._has_fill - has_stroke = submob._has_stroke - back_stroke = has_stroke and submob.stroke_behind - front_stroke = has_stroke and not submob.stroke_behind - if back_stroke: - back_stroke_datas.append(submob.data[stroke_names][indices]) - if front_stroke: - stroke_datas.append(submob.data[stroke_names][indices]) - if has_fill and submob._use_winding_fill: + if submob._use_winding_fill: data = submob.data[fill_names] data["base_point"][:] = data["point"][0] fill_datas.append(data[indices]) - if has_fill and not submob._use_winding_fill: + else: fill_datas.append(submob.data[fill_names]) fill_indices.append(submob.get_triangulation()) - if has_fill and not front_stroke: + if (not submob._has_stroke) or submob.stroke_behind: # Add fill border + submob.get_joint_products() names = list(stroke_names) names[names.index('stroke_rgba')] = 'fill_rgba' names[names.index('stroke_width')] = 'fill_border_width' @@ -1359,11 +1347,25 @@ class VMobject(Mobject): ) fill_border_datas.append(border_stroke_data[indices]) + # Build up stroke data lists + stroke_datas = [] + back_stroke_datas = [] + for submob in stroke_family: + submob.get_joint_products() + indices = submob.get_outer_vert_indices() + if submob.stroke_behind: + back_stroke_datas.append(submob.data[stroke_names][indices]) + else: + stroke_datas.append(submob.data[stroke_names][indices]) + shader_wrappers = [ self.back_stroke_shader_wrapper.read_in([*back_stroke_datas, *fill_border_datas]), self.fill_shader_wrapper.read_in(fill_datas, fill_indices or None), self.stroke_shader_wrapper.read_in(stroke_datas), ] + for sw in shader_wrappers: + sw.bind_to_mobject_uniforms(family[0].get_uniforms()) + sw.depth_test = family[0].depth_test return [sw for sw in shader_wrappers if len(sw.vert_data) > 0] @@ -1371,6 +1373,8 @@ class VGroup(VMobject): def __init__(self, *vmobjects: VMobject, **kwargs): super().__init__(**kwargs) self.add(*vmobjects) + if vmobjects: + self.uniforms.update(vmobjects[0].uniforms) def __add__(self, other: VMobject) -> Self: assert(isinstance(other, VMobject)) diff --git a/manimlib/scene/scene.py b/manimlib/scene/scene.py index 0a764737..d2af205a 100644 --- a/manimlib/scene/scene.py +++ b/manimlib/scene/scene.py @@ -387,7 +387,10 @@ class Scene(object): same type are grouped together, so this function creates Groups of all clusters of adjacent Mobjects in the scene """ - batches = batch_by_property(self.mobjects, lambda m: str(type(m))) + batches = batch_by_property( + self.mobjects, + lambda m: str(type(m)) + str(m.get_uniforms()) + ) for group in self.render_groups: group.clear() diff --git a/manimlib/shader_wrapper.py b/manimlib/shader_wrapper.py index dc5de477..454d78fe 100644 --- a/manimlib/shader_wrapper.py +++ b/manimlib/shader_wrapper.py @@ -37,7 +37,7 @@ class ShaderWrapper(object): vert_data: np.ndarray, vert_indices: Optional[np.ndarray] = None, shader_folder: Optional[str] = None, - uniforms: Optional[UniformDict] = None, # A dictionary mapping names of uniform variables + mobject_uniforms: Optional[UniformDict] = None, # A dictionary mapping names of uniform variables texture_paths: Optional[dict[str, str]] = None, # A dictionary mapping names to filepaths for textures. depth_test: bool = False, render_primitive: int = moderngl.TRIANGLE_STRIP, @@ -47,13 +47,14 @@ class ShaderWrapper(object): self.vert_indices = (vert_indices or np.zeros(0)).astype(int) self.vert_attributes = vert_data.dtype.names self.shader_folder = shader_folder - self.uniforms: UniformDict = dict() self.depth_test = depth_test self.render_primitive = render_primitive + self.program_uniform_mirror: UniformDict = dict() + self.bind_to_mobject_uniforms(mobject_uniforms) + self.init_program_code() self.init_program() - self.update_program_uniforms(uniforms or dict()) if texture_paths is not None: self.init_textures(texture_paths) self.init_vao() @@ -91,14 +92,17 @@ class ShaderWrapper(object): self.ibo = None self.vao = None + def bind_to_mobject_uniforms(self, mobject_uniforms: UniformDict): + self.mobject_uniforms = mobject_uniforms + def __eq__(self, shader_wrapper: ShaderWrapper): return all(( np.all(self.vert_data == shader_wrapper.vert_data), np.all(self.vert_indices == shader_wrapper.vert_indices), self.shader_folder == shader_wrapper.shader_folder, all( - self.uniforms[key] == shader_wrapper.uniforms[key] - for key in self.uniforms + self.mobject_uniforms[key] == shader_wrapper.mobject_uniforms[key] + for key in self.mobject_uniforms ), self.depth_test == shader_wrapper.depth_test, self.render_primitive == shader_wrapper.render_primitive, @@ -129,7 +133,7 @@ class ShaderWrapper(object): # A unique id for a shader return "|".join(map(str, [ self.program_id, - self.uniforms, + self.mobject_uniforms, self.depth_test, self.render_primitive, ])) @@ -155,9 +159,9 @@ class ShaderWrapper(object): # Changing context def use_clip_plane(self): - if "clip_plane" not in self.uniforms: + if "clip_plane" not in self.mobject_uniforms: return False - return any(self.uniforms["clip_plane"]) + return any(self.mobject_uniforms["clip_plane"]) def set_ctx_depth_test(self, enable: bool = True) -> None: if enable: @@ -222,18 +226,18 @@ class ShaderWrapper(object): assert(self.vao is not None) self.vao.render() - def update_program_uniforms(self, uniforms: UniformDict, universal: bool = False): + def update_program_uniforms(self, camera_uniforms: UniformDict): if self.program is None: return - for name, value in uniforms.items(): + for name, value in (*self.mobject_uniforms.items(), *camera_uniforms.items()): if name not in self.program: continue if isinstance(value, np.ndarray) and value.ndim > 0: value = tuple(value) - if universal and self.uniforms.get(name, None) == value: + if name in camera_uniforms and self.program_uniform_mirror.get(name, None) == value: continue self.program[name].value = value - self.uniforms[name] = value + self.program_uniform_mirror[name] = value def get_vertex_buffer_object(self, refresh: bool = True): if refresh: