mirror of
https://github.com/3b1b/manim.git
synced 2025-04-13 09:47:07 +00:00
Change how ShaderWrapper uniforms are handled
This commit is contained in:
parent
1dcc678b2f
commit
7f940fbee4
4 changed files with 49 additions and 40 deletions
|
@ -1967,7 +1967,7 @@ class Mobject(object):
|
|||
|
||||
self.shader_wrapper.vert_data = self.get_shader_data()
|
||||
self.shader_wrapper.vert_indices = self.get_shader_vert_indices()
|
||||
self.shader_wrapper.update_program_uniforms(self.get_uniforms())
|
||||
self.shader_wrapper.bind_to_mobject_uniforms(self.get_uniforms())
|
||||
self.shader_wrapper.depth_test = self.depth_test
|
||||
return self.shader_wrapper
|
||||
|
||||
|
@ -2004,9 +2004,7 @@ class Mobject(object):
|
|||
shader_wrapper.generate_vao()
|
||||
self._data_has_changed = False
|
||||
for shader_wrapper in self.shader_wrappers:
|
||||
shader_wrapper.depth_test = self.depth_test
|
||||
shader_wrapper.update_program_uniforms(self.get_uniforms())
|
||||
shader_wrapper.update_program_uniforms(camera_uniforms, universal=True)
|
||||
shader_wrapper.update_program_uniforms(camera_uniforms)
|
||||
shader_wrapper.pre_render()
|
||||
shader_wrapper.render()
|
||||
|
||||
|
|
|
@ -1284,14 +1284,14 @@ class VMobject(Mobject):
|
|||
self.fill_shader_wrapper = FillShaderWrapper(
|
||||
ctx=ctx,
|
||||
vert_data=fill_data,
|
||||
uniforms=self.uniforms,
|
||||
mobject_uniforms=self.uniforms,
|
||||
shader_folder=self.fill_shader_folder,
|
||||
render_primitive=self.fill_render_primitive,
|
||||
)
|
||||
self.stroke_shader_wrapper = ShaderWrapper(
|
||||
ctx=ctx,
|
||||
vert_data=stroke_data,
|
||||
uniforms=self.uniforms,
|
||||
mobject_uniforms=self.uniforms,
|
||||
shader_folder=self.stroke_shader_folder,
|
||||
render_primitive=self.stroke_render_primitive,
|
||||
)
|
||||
|
@ -1309,11 +1309,6 @@ class VMobject(Mobject):
|
|||
wrapper.refresh_id()
|
||||
return self
|
||||
|
||||
def get_uniforms(self):
|
||||
# TODO, account for submob uniforms separately?
|
||||
self.uniforms.update(self.family_members_with_points()[0].uniforms)
|
||||
return self.uniforms
|
||||
|
||||
def get_shader_wrapper_list(self, ctx: Context) -> list[ShaderWrapper]:
|
||||
if not self._shaders_initialized:
|
||||
self.init_shader_data(ctx)
|
||||
|
@ -1325,32 +1320,25 @@ class VMobject(Mobject):
|
|||
fill_names = self.fill_data_names
|
||||
stroke_names = self.stroke_data_names
|
||||
|
||||
# Build up data lists
|
||||
fill_family = (sm for sm in family if sm._has_fill)
|
||||
stroke_family = (sm for sm in family if sm._has_stroke)
|
||||
|
||||
# Build up fill data lists
|
||||
fill_datas = []
|
||||
fill_indices = []
|
||||
fill_border_datas = []
|
||||
stroke_datas = []
|
||||
back_stroke_datas = []
|
||||
for submob in family:
|
||||
submob.get_joint_products()
|
||||
for submob in fill_family:
|
||||
indices = submob.get_outer_vert_indices()
|
||||
has_fill = submob._has_fill
|
||||
has_stroke = submob._has_stroke
|
||||
back_stroke = has_stroke and submob.stroke_behind
|
||||
front_stroke = has_stroke and not submob.stroke_behind
|
||||
if back_stroke:
|
||||
back_stroke_datas.append(submob.data[stroke_names][indices])
|
||||
if front_stroke:
|
||||
stroke_datas.append(submob.data[stroke_names][indices])
|
||||
if has_fill and submob._use_winding_fill:
|
||||
if submob._use_winding_fill:
|
||||
data = submob.data[fill_names]
|
||||
data["base_point"][:] = data["point"][0]
|
||||
fill_datas.append(data[indices])
|
||||
if has_fill and not submob._use_winding_fill:
|
||||
else:
|
||||
fill_datas.append(submob.data[fill_names])
|
||||
fill_indices.append(submob.get_triangulation())
|
||||
if has_fill and not front_stroke:
|
||||
if (not submob._has_stroke) or submob.stroke_behind:
|
||||
# Add fill border
|
||||
submob.get_joint_products()
|
||||
names = list(stroke_names)
|
||||
names[names.index('stroke_rgba')] = 'fill_rgba'
|
||||
names[names.index('stroke_width')] = 'fill_border_width'
|
||||
|
@ -1359,11 +1347,25 @@ class VMobject(Mobject):
|
|||
)
|
||||
fill_border_datas.append(border_stroke_data[indices])
|
||||
|
||||
# Build up stroke data lists
|
||||
stroke_datas = []
|
||||
back_stroke_datas = []
|
||||
for submob in stroke_family:
|
||||
submob.get_joint_products()
|
||||
indices = submob.get_outer_vert_indices()
|
||||
if submob.stroke_behind:
|
||||
back_stroke_datas.append(submob.data[stroke_names][indices])
|
||||
else:
|
||||
stroke_datas.append(submob.data[stroke_names][indices])
|
||||
|
||||
shader_wrappers = [
|
||||
self.back_stroke_shader_wrapper.read_in([*back_stroke_datas, *fill_border_datas]),
|
||||
self.fill_shader_wrapper.read_in(fill_datas, fill_indices or None),
|
||||
self.stroke_shader_wrapper.read_in(stroke_datas),
|
||||
]
|
||||
for sw in shader_wrappers:
|
||||
sw.bind_to_mobject_uniforms(family[0].get_uniforms())
|
||||
sw.depth_test = family[0].depth_test
|
||||
return [sw for sw in shader_wrappers if len(sw.vert_data) > 0]
|
||||
|
||||
|
||||
|
@ -1371,6 +1373,8 @@ class VGroup(VMobject):
|
|||
def __init__(self, *vmobjects: VMobject, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.add(*vmobjects)
|
||||
if vmobjects:
|
||||
self.uniforms.update(vmobjects[0].uniforms)
|
||||
|
||||
def __add__(self, other: VMobject) -> Self:
|
||||
assert(isinstance(other, VMobject))
|
||||
|
|
|
@ -387,7 +387,10 @@ class Scene(object):
|
|||
same type are grouped together, so this function creates
|
||||
Groups of all clusters of adjacent Mobjects in the scene
|
||||
"""
|
||||
batches = batch_by_property(self.mobjects, lambda m: str(type(m)))
|
||||
batches = batch_by_property(
|
||||
self.mobjects,
|
||||
lambda m: str(type(m)) + str(m.get_uniforms())
|
||||
)
|
||||
|
||||
for group in self.render_groups:
|
||||
group.clear()
|
||||
|
|
|
@ -37,7 +37,7 @@ class ShaderWrapper(object):
|
|||
vert_data: np.ndarray,
|
||||
vert_indices: Optional[np.ndarray] = None,
|
||||
shader_folder: Optional[str] = None,
|
||||
uniforms: Optional[UniformDict] = None, # A dictionary mapping names of uniform variables
|
||||
mobject_uniforms: Optional[UniformDict] = None, # A dictionary mapping names of uniform variables
|
||||
texture_paths: Optional[dict[str, str]] = None, # A dictionary mapping names to filepaths for textures.
|
||||
depth_test: bool = False,
|
||||
render_primitive: int = moderngl.TRIANGLE_STRIP,
|
||||
|
@ -47,13 +47,14 @@ class ShaderWrapper(object):
|
|||
self.vert_indices = (vert_indices or np.zeros(0)).astype(int)
|
||||
self.vert_attributes = vert_data.dtype.names
|
||||
self.shader_folder = shader_folder
|
||||
self.uniforms: UniformDict = dict()
|
||||
self.depth_test = depth_test
|
||||
self.render_primitive = render_primitive
|
||||
|
||||
self.program_uniform_mirror: UniformDict = dict()
|
||||
self.bind_to_mobject_uniforms(mobject_uniforms)
|
||||
|
||||
self.init_program_code()
|
||||
self.init_program()
|
||||
self.update_program_uniforms(uniforms or dict())
|
||||
if texture_paths is not None:
|
||||
self.init_textures(texture_paths)
|
||||
self.init_vao()
|
||||
|
@ -91,14 +92,17 @@ class ShaderWrapper(object):
|
|||
self.ibo = None
|
||||
self.vao = None
|
||||
|
||||
def bind_to_mobject_uniforms(self, mobject_uniforms: UniformDict):
|
||||
self.mobject_uniforms = mobject_uniforms
|
||||
|
||||
def __eq__(self, shader_wrapper: ShaderWrapper):
|
||||
return all((
|
||||
np.all(self.vert_data == shader_wrapper.vert_data),
|
||||
np.all(self.vert_indices == shader_wrapper.vert_indices),
|
||||
self.shader_folder == shader_wrapper.shader_folder,
|
||||
all(
|
||||
self.uniforms[key] == shader_wrapper.uniforms[key]
|
||||
for key in self.uniforms
|
||||
self.mobject_uniforms[key] == shader_wrapper.mobject_uniforms[key]
|
||||
for key in self.mobject_uniforms
|
||||
),
|
||||
self.depth_test == shader_wrapper.depth_test,
|
||||
self.render_primitive == shader_wrapper.render_primitive,
|
||||
|
@ -129,7 +133,7 @@ class ShaderWrapper(object):
|
|||
# A unique id for a shader
|
||||
return "|".join(map(str, [
|
||||
self.program_id,
|
||||
self.uniforms,
|
||||
self.mobject_uniforms,
|
||||
self.depth_test,
|
||||
self.render_primitive,
|
||||
]))
|
||||
|
@ -155,9 +159,9 @@ class ShaderWrapper(object):
|
|||
|
||||
# Changing context
|
||||
def use_clip_plane(self):
|
||||
if "clip_plane" not in self.uniforms:
|
||||
if "clip_plane" not in self.mobject_uniforms:
|
||||
return False
|
||||
return any(self.uniforms["clip_plane"])
|
||||
return any(self.mobject_uniforms["clip_plane"])
|
||||
|
||||
def set_ctx_depth_test(self, enable: bool = True) -> None:
|
||||
if enable:
|
||||
|
@ -222,18 +226,18 @@ class ShaderWrapper(object):
|
|||
assert(self.vao is not None)
|
||||
self.vao.render()
|
||||
|
||||
def update_program_uniforms(self, uniforms: UniformDict, universal: bool = False):
|
||||
def update_program_uniforms(self, camera_uniforms: UniformDict):
|
||||
if self.program is None:
|
||||
return
|
||||
for name, value in uniforms.items():
|
||||
for name, value in (*self.mobject_uniforms.items(), *camera_uniforms.items()):
|
||||
if name not in self.program:
|
||||
continue
|
||||
if isinstance(value, np.ndarray) and value.ndim > 0:
|
||||
value = tuple(value)
|
||||
if universal and self.uniforms.get(name, None) == value:
|
||||
if name in camera_uniforms and self.program_uniform_mirror.get(name, None) == value:
|
||||
continue
|
||||
self.program[name].value = value
|
||||
self.uniforms[name] = value
|
||||
self.program_uniform_mirror[name] = value
|
||||
|
||||
def get_vertex_buffer_object(self, refresh: bool = True):
|
||||
if refresh:
|
||||
|
|
Loading…
Add table
Reference in a new issue