Change how ShaderWrapper uniforms are handled

This commit is contained in:
Grant Sanderson 2023-02-02 17:45:52 -08:00
parent 1dcc678b2f
commit 7f940fbee4
4 changed files with 49 additions and 40 deletions

View file

@ -1967,7 +1967,7 @@ class Mobject(object):
self.shader_wrapper.vert_data = self.get_shader_data() self.shader_wrapper.vert_data = self.get_shader_data()
self.shader_wrapper.vert_indices = self.get_shader_vert_indices() self.shader_wrapper.vert_indices = self.get_shader_vert_indices()
self.shader_wrapper.update_program_uniforms(self.get_uniforms()) self.shader_wrapper.bind_to_mobject_uniforms(self.get_uniforms())
self.shader_wrapper.depth_test = self.depth_test self.shader_wrapper.depth_test = self.depth_test
return self.shader_wrapper return self.shader_wrapper
@ -2004,9 +2004,7 @@ class Mobject(object):
shader_wrapper.generate_vao() shader_wrapper.generate_vao()
self._data_has_changed = False self._data_has_changed = False
for shader_wrapper in self.shader_wrappers: for shader_wrapper in self.shader_wrappers:
shader_wrapper.depth_test = self.depth_test shader_wrapper.update_program_uniforms(camera_uniforms)
shader_wrapper.update_program_uniforms(self.get_uniforms())
shader_wrapper.update_program_uniforms(camera_uniforms, universal=True)
shader_wrapper.pre_render() shader_wrapper.pre_render()
shader_wrapper.render() shader_wrapper.render()

View file

@ -1284,14 +1284,14 @@ class VMobject(Mobject):
self.fill_shader_wrapper = FillShaderWrapper( self.fill_shader_wrapper = FillShaderWrapper(
ctx=ctx, ctx=ctx,
vert_data=fill_data, vert_data=fill_data,
uniforms=self.uniforms, mobject_uniforms=self.uniforms,
shader_folder=self.fill_shader_folder, shader_folder=self.fill_shader_folder,
render_primitive=self.fill_render_primitive, render_primitive=self.fill_render_primitive,
) )
self.stroke_shader_wrapper = ShaderWrapper( self.stroke_shader_wrapper = ShaderWrapper(
ctx=ctx, ctx=ctx,
vert_data=stroke_data, vert_data=stroke_data,
uniforms=self.uniforms, mobject_uniforms=self.uniforms,
shader_folder=self.stroke_shader_folder, shader_folder=self.stroke_shader_folder,
render_primitive=self.stroke_render_primitive, render_primitive=self.stroke_render_primitive,
) )
@ -1309,11 +1309,6 @@ class VMobject(Mobject):
wrapper.refresh_id() wrapper.refresh_id()
return self return self
def get_uniforms(self):
# TODO, account for submob uniforms separately?
self.uniforms.update(self.family_members_with_points()[0].uniforms)
return self.uniforms
def get_shader_wrapper_list(self, ctx: Context) -> list[ShaderWrapper]: def get_shader_wrapper_list(self, ctx: Context) -> list[ShaderWrapper]:
if not self._shaders_initialized: if not self._shaders_initialized:
self.init_shader_data(ctx) self.init_shader_data(ctx)
@ -1325,32 +1320,25 @@ class VMobject(Mobject):
fill_names = self.fill_data_names fill_names = self.fill_data_names
stroke_names = self.stroke_data_names stroke_names = self.stroke_data_names
# Build up data lists fill_family = (sm for sm in family if sm._has_fill)
stroke_family = (sm for sm in family if sm._has_stroke)
# Build up fill data lists
fill_datas = [] fill_datas = []
fill_indices = [] fill_indices = []
fill_border_datas = [] fill_border_datas = []
stroke_datas = [] for submob in fill_family:
back_stroke_datas = []
for submob in family:
submob.get_joint_products()
indices = submob.get_outer_vert_indices() indices = submob.get_outer_vert_indices()
has_fill = submob._has_fill if submob._use_winding_fill:
has_stroke = submob._has_stroke
back_stroke = has_stroke and submob.stroke_behind
front_stroke = has_stroke and not submob.stroke_behind
if back_stroke:
back_stroke_datas.append(submob.data[stroke_names][indices])
if front_stroke:
stroke_datas.append(submob.data[stroke_names][indices])
if has_fill and submob._use_winding_fill:
data = submob.data[fill_names] data = submob.data[fill_names]
data["base_point"][:] = data["point"][0] data["base_point"][:] = data["point"][0]
fill_datas.append(data[indices]) fill_datas.append(data[indices])
if has_fill and not submob._use_winding_fill: else:
fill_datas.append(submob.data[fill_names]) fill_datas.append(submob.data[fill_names])
fill_indices.append(submob.get_triangulation()) fill_indices.append(submob.get_triangulation())
if has_fill and not front_stroke: if (not submob._has_stroke) or submob.stroke_behind:
# Add fill border # Add fill border
submob.get_joint_products()
names = list(stroke_names) names = list(stroke_names)
names[names.index('stroke_rgba')] = 'fill_rgba' names[names.index('stroke_rgba')] = 'fill_rgba'
names[names.index('stroke_width')] = 'fill_border_width' names[names.index('stroke_width')] = 'fill_border_width'
@ -1359,11 +1347,25 @@ class VMobject(Mobject):
) )
fill_border_datas.append(border_stroke_data[indices]) fill_border_datas.append(border_stroke_data[indices])
# Build up stroke data lists
stroke_datas = []
back_stroke_datas = []
for submob in stroke_family:
submob.get_joint_products()
indices = submob.get_outer_vert_indices()
if submob.stroke_behind:
back_stroke_datas.append(submob.data[stroke_names][indices])
else:
stroke_datas.append(submob.data[stroke_names][indices])
shader_wrappers = [ shader_wrappers = [
self.back_stroke_shader_wrapper.read_in([*back_stroke_datas, *fill_border_datas]), self.back_stroke_shader_wrapper.read_in([*back_stroke_datas, *fill_border_datas]),
self.fill_shader_wrapper.read_in(fill_datas, fill_indices or None), self.fill_shader_wrapper.read_in(fill_datas, fill_indices or None),
self.stroke_shader_wrapper.read_in(stroke_datas), self.stroke_shader_wrapper.read_in(stroke_datas),
] ]
for sw in shader_wrappers:
sw.bind_to_mobject_uniforms(family[0].get_uniforms())
sw.depth_test = family[0].depth_test
return [sw for sw in shader_wrappers if len(sw.vert_data) > 0] return [sw for sw in shader_wrappers if len(sw.vert_data) > 0]
@ -1371,6 +1373,8 @@ class VGroup(VMobject):
def __init__(self, *vmobjects: VMobject, **kwargs): def __init__(self, *vmobjects: VMobject, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.add(*vmobjects) self.add(*vmobjects)
if vmobjects:
self.uniforms.update(vmobjects[0].uniforms)
def __add__(self, other: VMobject) -> Self: def __add__(self, other: VMobject) -> Self:
assert(isinstance(other, VMobject)) assert(isinstance(other, VMobject))

View file

@ -387,7 +387,10 @@ class Scene(object):
same type are grouped together, so this function creates same type are grouped together, so this function creates
Groups of all clusters of adjacent Mobjects in the scene Groups of all clusters of adjacent Mobjects in the scene
""" """
batches = batch_by_property(self.mobjects, lambda m: str(type(m))) batches = batch_by_property(
self.mobjects,
lambda m: str(type(m)) + str(m.get_uniforms())
)
for group in self.render_groups: for group in self.render_groups:
group.clear() group.clear()

View file

@ -37,7 +37,7 @@ class ShaderWrapper(object):
vert_data: np.ndarray, vert_data: np.ndarray,
vert_indices: Optional[np.ndarray] = None, vert_indices: Optional[np.ndarray] = None,
shader_folder: Optional[str] = None, shader_folder: Optional[str] = None,
uniforms: Optional[UniformDict] = None, # A dictionary mapping names of uniform variables mobject_uniforms: Optional[UniformDict] = None, # A dictionary mapping names of uniform variables
texture_paths: Optional[dict[str, str]] = None, # A dictionary mapping names to filepaths for textures. texture_paths: Optional[dict[str, str]] = None, # A dictionary mapping names to filepaths for textures.
depth_test: bool = False, depth_test: bool = False,
render_primitive: int = moderngl.TRIANGLE_STRIP, render_primitive: int = moderngl.TRIANGLE_STRIP,
@ -47,13 +47,14 @@ class ShaderWrapper(object):
self.vert_indices = (vert_indices or np.zeros(0)).astype(int) self.vert_indices = (vert_indices or np.zeros(0)).astype(int)
self.vert_attributes = vert_data.dtype.names self.vert_attributes = vert_data.dtype.names
self.shader_folder = shader_folder self.shader_folder = shader_folder
self.uniforms: UniformDict = dict()
self.depth_test = depth_test self.depth_test = depth_test
self.render_primitive = render_primitive self.render_primitive = render_primitive
self.program_uniform_mirror: UniformDict = dict()
self.bind_to_mobject_uniforms(mobject_uniforms)
self.init_program_code() self.init_program_code()
self.init_program() self.init_program()
self.update_program_uniforms(uniforms or dict())
if texture_paths is not None: if texture_paths is not None:
self.init_textures(texture_paths) self.init_textures(texture_paths)
self.init_vao() self.init_vao()
@ -91,14 +92,17 @@ class ShaderWrapper(object):
self.ibo = None self.ibo = None
self.vao = None self.vao = None
def bind_to_mobject_uniforms(self, mobject_uniforms: UniformDict):
self.mobject_uniforms = mobject_uniforms
def __eq__(self, shader_wrapper: ShaderWrapper): def __eq__(self, shader_wrapper: ShaderWrapper):
return all(( return all((
np.all(self.vert_data == shader_wrapper.vert_data), np.all(self.vert_data == shader_wrapper.vert_data),
np.all(self.vert_indices == shader_wrapper.vert_indices), np.all(self.vert_indices == shader_wrapper.vert_indices),
self.shader_folder == shader_wrapper.shader_folder, self.shader_folder == shader_wrapper.shader_folder,
all( all(
self.uniforms[key] == shader_wrapper.uniforms[key] self.mobject_uniforms[key] == shader_wrapper.mobject_uniforms[key]
for key in self.uniforms for key in self.mobject_uniforms
), ),
self.depth_test == shader_wrapper.depth_test, self.depth_test == shader_wrapper.depth_test,
self.render_primitive == shader_wrapper.render_primitive, self.render_primitive == shader_wrapper.render_primitive,
@ -129,7 +133,7 @@ class ShaderWrapper(object):
# A unique id for a shader # A unique id for a shader
return "|".join(map(str, [ return "|".join(map(str, [
self.program_id, self.program_id,
self.uniforms, self.mobject_uniforms,
self.depth_test, self.depth_test,
self.render_primitive, self.render_primitive,
])) ]))
@ -155,9 +159,9 @@ class ShaderWrapper(object):
# Changing context # Changing context
def use_clip_plane(self): def use_clip_plane(self):
if "clip_plane" not in self.uniforms: if "clip_plane" not in self.mobject_uniforms:
return False return False
return any(self.uniforms["clip_plane"]) return any(self.mobject_uniforms["clip_plane"])
def set_ctx_depth_test(self, enable: bool = True) -> None: def set_ctx_depth_test(self, enable: bool = True) -> None:
if enable: if enable:
@ -222,18 +226,18 @@ class ShaderWrapper(object):
assert(self.vao is not None) assert(self.vao is not None)
self.vao.render() self.vao.render()
def update_program_uniforms(self, uniforms: UniformDict, universal: bool = False): def update_program_uniforms(self, camera_uniforms: UniformDict):
if self.program is None: if self.program is None:
return return
for name, value in uniforms.items(): for name, value in (*self.mobject_uniforms.items(), *camera_uniforms.items()):
if name not in self.program: if name not in self.program:
continue continue
if isinstance(value, np.ndarray) and value.ndim > 0: if isinstance(value, np.ndarray) and value.ndim > 0:
value = tuple(value) value = tuple(value)
if universal and self.uniforms.get(name, None) == value: if name in camera_uniforms and self.program_uniform_mirror.get(name, None) == value:
continue continue
self.program[name].value = value self.program[name].value = value
self.uniforms[name] = value self.program_uniform_mirror[name] = value
def get_vertex_buffer_object(self, refresh: bool = True): def get_vertex_buffer_object(self, refresh: bool = True):
if refresh: if refresh: