Change how ShaderWrapper uniforms are handled

This commit is contained in:
Grant Sanderson 2023-02-02 17:45:52 -08:00
parent 1dcc678b2f
commit 7f940fbee4
4 changed files with 49 additions and 40 deletions

View file

@ -1967,7 +1967,7 @@ class Mobject(object):
self.shader_wrapper.vert_data = self.get_shader_data()
self.shader_wrapper.vert_indices = self.get_shader_vert_indices()
self.shader_wrapper.update_program_uniforms(self.get_uniforms())
self.shader_wrapper.bind_to_mobject_uniforms(self.get_uniforms())
self.shader_wrapper.depth_test = self.depth_test
return self.shader_wrapper
@ -2004,9 +2004,7 @@ class Mobject(object):
shader_wrapper.generate_vao()
self._data_has_changed = False
for shader_wrapper in self.shader_wrappers:
shader_wrapper.depth_test = self.depth_test
shader_wrapper.update_program_uniforms(self.get_uniforms())
shader_wrapper.update_program_uniforms(camera_uniforms, universal=True)
shader_wrapper.update_program_uniforms(camera_uniforms)
shader_wrapper.pre_render()
shader_wrapper.render()

View file

@ -1284,14 +1284,14 @@ class VMobject(Mobject):
self.fill_shader_wrapper = FillShaderWrapper(
ctx=ctx,
vert_data=fill_data,
uniforms=self.uniforms,
mobject_uniforms=self.uniforms,
shader_folder=self.fill_shader_folder,
render_primitive=self.fill_render_primitive,
)
self.stroke_shader_wrapper = ShaderWrapper(
ctx=ctx,
vert_data=stroke_data,
uniforms=self.uniforms,
mobject_uniforms=self.uniforms,
shader_folder=self.stroke_shader_folder,
render_primitive=self.stroke_render_primitive,
)
@ -1309,11 +1309,6 @@ class VMobject(Mobject):
wrapper.refresh_id()
return self
def get_uniforms(self):
# TODO, account for submob uniforms separately?
self.uniforms.update(self.family_members_with_points()[0].uniforms)
return self.uniforms
def get_shader_wrapper_list(self, ctx: Context) -> list[ShaderWrapper]:
if not self._shaders_initialized:
self.init_shader_data(ctx)
@ -1325,32 +1320,25 @@ class VMobject(Mobject):
fill_names = self.fill_data_names
stroke_names = self.stroke_data_names
# Build up data lists
fill_family = (sm for sm in family if sm._has_fill)
stroke_family = (sm for sm in family if sm._has_stroke)
# Build up fill data lists
fill_datas = []
fill_indices = []
fill_border_datas = []
stroke_datas = []
back_stroke_datas = []
for submob in family:
submob.get_joint_products()
for submob in fill_family:
indices = submob.get_outer_vert_indices()
has_fill = submob._has_fill
has_stroke = submob._has_stroke
back_stroke = has_stroke and submob.stroke_behind
front_stroke = has_stroke and not submob.stroke_behind
if back_stroke:
back_stroke_datas.append(submob.data[stroke_names][indices])
if front_stroke:
stroke_datas.append(submob.data[stroke_names][indices])
if has_fill and submob._use_winding_fill:
if submob._use_winding_fill:
data = submob.data[fill_names]
data["base_point"][:] = data["point"][0]
fill_datas.append(data[indices])
if has_fill and not submob._use_winding_fill:
else:
fill_datas.append(submob.data[fill_names])
fill_indices.append(submob.get_triangulation())
if has_fill and not front_stroke:
if (not submob._has_stroke) or submob.stroke_behind:
# Add fill border
submob.get_joint_products()
names = list(stroke_names)
names[names.index('stroke_rgba')] = 'fill_rgba'
names[names.index('stroke_width')] = 'fill_border_width'
@ -1359,11 +1347,25 @@ class VMobject(Mobject):
)
fill_border_datas.append(border_stroke_data[indices])
# Build up stroke data lists
stroke_datas = []
back_stroke_datas = []
for submob in stroke_family:
submob.get_joint_products()
indices = submob.get_outer_vert_indices()
if submob.stroke_behind:
back_stroke_datas.append(submob.data[stroke_names][indices])
else:
stroke_datas.append(submob.data[stroke_names][indices])
shader_wrappers = [
self.back_stroke_shader_wrapper.read_in([*back_stroke_datas, *fill_border_datas]),
self.fill_shader_wrapper.read_in(fill_datas, fill_indices or None),
self.stroke_shader_wrapper.read_in(stroke_datas),
]
for sw in shader_wrappers:
sw.bind_to_mobject_uniforms(family[0].get_uniforms())
sw.depth_test = family[0].depth_test
return [sw for sw in shader_wrappers if len(sw.vert_data) > 0]
@ -1371,6 +1373,8 @@ class VGroup(VMobject):
def __init__(self, *vmobjects: VMobject, **kwargs):
super().__init__(**kwargs)
self.add(*vmobjects)
if vmobjects:
self.uniforms.update(vmobjects[0].uniforms)
def __add__(self, other: VMobject) -> Self:
assert(isinstance(other, VMobject))

View file

@ -387,7 +387,10 @@ class Scene(object):
same type are grouped together, so this function creates
Groups of all clusters of adjacent Mobjects in the scene
"""
batches = batch_by_property(self.mobjects, lambda m: str(type(m)))
batches = batch_by_property(
self.mobjects,
lambda m: str(type(m)) + str(m.get_uniforms())
)
for group in self.render_groups:
group.clear()

View file

@ -37,7 +37,7 @@ class ShaderWrapper(object):
vert_data: np.ndarray,
vert_indices: Optional[np.ndarray] = None,
shader_folder: Optional[str] = None,
uniforms: Optional[UniformDict] = None, # A dictionary mapping names of uniform variables
mobject_uniforms: Optional[UniformDict] = None, # A dictionary mapping names of uniform variables
texture_paths: Optional[dict[str, str]] = None, # A dictionary mapping names to filepaths for textures.
depth_test: bool = False,
render_primitive: int = moderngl.TRIANGLE_STRIP,
@ -47,13 +47,14 @@ class ShaderWrapper(object):
self.vert_indices = (vert_indices or np.zeros(0)).astype(int)
self.vert_attributes = vert_data.dtype.names
self.shader_folder = shader_folder
self.uniforms: UniformDict = dict()
self.depth_test = depth_test
self.render_primitive = render_primitive
self.program_uniform_mirror: UniformDict = dict()
self.bind_to_mobject_uniforms(mobject_uniforms)
self.init_program_code()
self.init_program()
self.update_program_uniforms(uniforms or dict())
if texture_paths is not None:
self.init_textures(texture_paths)
self.init_vao()
@ -91,14 +92,17 @@ class ShaderWrapper(object):
self.ibo = None
self.vao = None
def bind_to_mobject_uniforms(self, mobject_uniforms: UniformDict):
self.mobject_uniforms = mobject_uniforms
def __eq__(self, shader_wrapper: ShaderWrapper):
return all((
np.all(self.vert_data == shader_wrapper.vert_data),
np.all(self.vert_indices == shader_wrapper.vert_indices),
self.shader_folder == shader_wrapper.shader_folder,
all(
self.uniforms[key] == shader_wrapper.uniforms[key]
for key in self.uniforms
self.mobject_uniforms[key] == shader_wrapper.mobject_uniforms[key]
for key in self.mobject_uniforms
),
self.depth_test == shader_wrapper.depth_test,
self.render_primitive == shader_wrapper.render_primitive,
@ -129,7 +133,7 @@ class ShaderWrapper(object):
# A unique id for a shader
return "|".join(map(str, [
self.program_id,
self.uniforms,
self.mobject_uniforms,
self.depth_test,
self.render_primitive,
]))
@ -155,9 +159,9 @@ class ShaderWrapper(object):
# Changing context
def use_clip_plane(self):
if "clip_plane" not in self.uniforms:
if "clip_plane" not in self.mobject_uniforms:
return False
return any(self.uniforms["clip_plane"])
return any(self.mobject_uniforms["clip_plane"])
def set_ctx_depth_test(self, enable: bool = True) -> None:
if enable:
@ -222,18 +226,18 @@ class ShaderWrapper(object):
assert(self.vao is not None)
self.vao.render()
def update_program_uniforms(self, uniforms: UniformDict, universal: bool = False):
def update_program_uniforms(self, camera_uniforms: UniformDict):
if self.program is None:
return
for name, value in uniforms.items():
for name, value in (*self.mobject_uniforms.items(), *camera_uniforms.items()):
if name not in self.program:
continue
if isinstance(value, np.ndarray) and value.ndim > 0:
value = tuple(value)
if universal and self.uniforms.get(name, None) == value:
if name in camera_uniforms and self.program_uniform_mirror.get(name, None) == value:
continue
self.program[name].value = value
self.uniforms[name] = value
self.program_uniform_mirror[name] = value
def get_vertex_buffer_object(self, refresh: bool = True):
if refresh: