mirror of
https://github.com/3b1b/manim.git
synced 2025-04-13 09:47:07 +00:00
Change how ShaderWrapper uniforms are handled
This commit is contained in:
parent
1dcc678b2f
commit
7f940fbee4
4 changed files with 49 additions and 40 deletions
|
@ -1967,7 +1967,7 @@ class Mobject(object):
|
||||||
|
|
||||||
self.shader_wrapper.vert_data = self.get_shader_data()
|
self.shader_wrapper.vert_data = self.get_shader_data()
|
||||||
self.shader_wrapper.vert_indices = self.get_shader_vert_indices()
|
self.shader_wrapper.vert_indices = self.get_shader_vert_indices()
|
||||||
self.shader_wrapper.update_program_uniforms(self.get_uniforms())
|
self.shader_wrapper.bind_to_mobject_uniforms(self.get_uniforms())
|
||||||
self.shader_wrapper.depth_test = self.depth_test
|
self.shader_wrapper.depth_test = self.depth_test
|
||||||
return self.shader_wrapper
|
return self.shader_wrapper
|
||||||
|
|
||||||
|
@ -2004,9 +2004,7 @@ class Mobject(object):
|
||||||
shader_wrapper.generate_vao()
|
shader_wrapper.generate_vao()
|
||||||
self._data_has_changed = False
|
self._data_has_changed = False
|
||||||
for shader_wrapper in self.shader_wrappers:
|
for shader_wrapper in self.shader_wrappers:
|
||||||
shader_wrapper.depth_test = self.depth_test
|
shader_wrapper.update_program_uniforms(camera_uniforms)
|
||||||
shader_wrapper.update_program_uniforms(self.get_uniforms())
|
|
||||||
shader_wrapper.update_program_uniforms(camera_uniforms, universal=True)
|
|
||||||
shader_wrapper.pre_render()
|
shader_wrapper.pre_render()
|
||||||
shader_wrapper.render()
|
shader_wrapper.render()
|
||||||
|
|
||||||
|
|
|
@ -1284,14 +1284,14 @@ class VMobject(Mobject):
|
||||||
self.fill_shader_wrapper = FillShaderWrapper(
|
self.fill_shader_wrapper = FillShaderWrapper(
|
||||||
ctx=ctx,
|
ctx=ctx,
|
||||||
vert_data=fill_data,
|
vert_data=fill_data,
|
||||||
uniforms=self.uniforms,
|
mobject_uniforms=self.uniforms,
|
||||||
shader_folder=self.fill_shader_folder,
|
shader_folder=self.fill_shader_folder,
|
||||||
render_primitive=self.fill_render_primitive,
|
render_primitive=self.fill_render_primitive,
|
||||||
)
|
)
|
||||||
self.stroke_shader_wrapper = ShaderWrapper(
|
self.stroke_shader_wrapper = ShaderWrapper(
|
||||||
ctx=ctx,
|
ctx=ctx,
|
||||||
vert_data=stroke_data,
|
vert_data=stroke_data,
|
||||||
uniforms=self.uniforms,
|
mobject_uniforms=self.uniforms,
|
||||||
shader_folder=self.stroke_shader_folder,
|
shader_folder=self.stroke_shader_folder,
|
||||||
render_primitive=self.stroke_render_primitive,
|
render_primitive=self.stroke_render_primitive,
|
||||||
)
|
)
|
||||||
|
@ -1309,11 +1309,6 @@ class VMobject(Mobject):
|
||||||
wrapper.refresh_id()
|
wrapper.refresh_id()
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def get_uniforms(self):
|
|
||||||
# TODO, account for submob uniforms separately?
|
|
||||||
self.uniforms.update(self.family_members_with_points()[0].uniforms)
|
|
||||||
return self.uniforms
|
|
||||||
|
|
||||||
def get_shader_wrapper_list(self, ctx: Context) -> list[ShaderWrapper]:
|
def get_shader_wrapper_list(self, ctx: Context) -> list[ShaderWrapper]:
|
||||||
if not self._shaders_initialized:
|
if not self._shaders_initialized:
|
||||||
self.init_shader_data(ctx)
|
self.init_shader_data(ctx)
|
||||||
|
@ -1325,32 +1320,25 @@ class VMobject(Mobject):
|
||||||
fill_names = self.fill_data_names
|
fill_names = self.fill_data_names
|
||||||
stroke_names = self.stroke_data_names
|
stroke_names = self.stroke_data_names
|
||||||
|
|
||||||
# Build up data lists
|
fill_family = (sm for sm in family if sm._has_fill)
|
||||||
|
stroke_family = (sm for sm in family if sm._has_stroke)
|
||||||
|
|
||||||
|
# Build up fill data lists
|
||||||
fill_datas = []
|
fill_datas = []
|
||||||
fill_indices = []
|
fill_indices = []
|
||||||
fill_border_datas = []
|
fill_border_datas = []
|
||||||
stroke_datas = []
|
for submob in fill_family:
|
||||||
back_stroke_datas = []
|
|
||||||
for submob in family:
|
|
||||||
submob.get_joint_products()
|
|
||||||
indices = submob.get_outer_vert_indices()
|
indices = submob.get_outer_vert_indices()
|
||||||
has_fill = submob._has_fill
|
if submob._use_winding_fill:
|
||||||
has_stroke = submob._has_stroke
|
|
||||||
back_stroke = has_stroke and submob.stroke_behind
|
|
||||||
front_stroke = has_stroke and not submob.stroke_behind
|
|
||||||
if back_stroke:
|
|
||||||
back_stroke_datas.append(submob.data[stroke_names][indices])
|
|
||||||
if front_stroke:
|
|
||||||
stroke_datas.append(submob.data[stroke_names][indices])
|
|
||||||
if has_fill and submob._use_winding_fill:
|
|
||||||
data = submob.data[fill_names]
|
data = submob.data[fill_names]
|
||||||
data["base_point"][:] = data["point"][0]
|
data["base_point"][:] = data["point"][0]
|
||||||
fill_datas.append(data[indices])
|
fill_datas.append(data[indices])
|
||||||
if has_fill and not submob._use_winding_fill:
|
else:
|
||||||
fill_datas.append(submob.data[fill_names])
|
fill_datas.append(submob.data[fill_names])
|
||||||
fill_indices.append(submob.get_triangulation())
|
fill_indices.append(submob.get_triangulation())
|
||||||
if has_fill and not front_stroke:
|
if (not submob._has_stroke) or submob.stroke_behind:
|
||||||
# Add fill border
|
# Add fill border
|
||||||
|
submob.get_joint_products()
|
||||||
names = list(stroke_names)
|
names = list(stroke_names)
|
||||||
names[names.index('stroke_rgba')] = 'fill_rgba'
|
names[names.index('stroke_rgba')] = 'fill_rgba'
|
||||||
names[names.index('stroke_width')] = 'fill_border_width'
|
names[names.index('stroke_width')] = 'fill_border_width'
|
||||||
|
@ -1359,11 +1347,25 @@ class VMobject(Mobject):
|
||||||
)
|
)
|
||||||
fill_border_datas.append(border_stroke_data[indices])
|
fill_border_datas.append(border_stroke_data[indices])
|
||||||
|
|
||||||
|
# Build up stroke data lists
|
||||||
|
stroke_datas = []
|
||||||
|
back_stroke_datas = []
|
||||||
|
for submob in stroke_family:
|
||||||
|
submob.get_joint_products()
|
||||||
|
indices = submob.get_outer_vert_indices()
|
||||||
|
if submob.stroke_behind:
|
||||||
|
back_stroke_datas.append(submob.data[stroke_names][indices])
|
||||||
|
else:
|
||||||
|
stroke_datas.append(submob.data[stroke_names][indices])
|
||||||
|
|
||||||
shader_wrappers = [
|
shader_wrappers = [
|
||||||
self.back_stroke_shader_wrapper.read_in([*back_stroke_datas, *fill_border_datas]),
|
self.back_stroke_shader_wrapper.read_in([*back_stroke_datas, *fill_border_datas]),
|
||||||
self.fill_shader_wrapper.read_in(fill_datas, fill_indices or None),
|
self.fill_shader_wrapper.read_in(fill_datas, fill_indices or None),
|
||||||
self.stroke_shader_wrapper.read_in(stroke_datas),
|
self.stroke_shader_wrapper.read_in(stroke_datas),
|
||||||
]
|
]
|
||||||
|
for sw in shader_wrappers:
|
||||||
|
sw.bind_to_mobject_uniforms(family[0].get_uniforms())
|
||||||
|
sw.depth_test = family[0].depth_test
|
||||||
return [sw for sw in shader_wrappers if len(sw.vert_data) > 0]
|
return [sw for sw in shader_wrappers if len(sw.vert_data) > 0]
|
||||||
|
|
||||||
|
|
||||||
|
@ -1371,6 +1373,8 @@ class VGroup(VMobject):
|
||||||
def __init__(self, *vmobjects: VMobject, **kwargs):
|
def __init__(self, *vmobjects: VMobject, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.add(*vmobjects)
|
self.add(*vmobjects)
|
||||||
|
if vmobjects:
|
||||||
|
self.uniforms.update(vmobjects[0].uniforms)
|
||||||
|
|
||||||
def __add__(self, other: VMobject) -> Self:
|
def __add__(self, other: VMobject) -> Self:
|
||||||
assert(isinstance(other, VMobject))
|
assert(isinstance(other, VMobject))
|
||||||
|
|
|
@ -387,7 +387,10 @@ class Scene(object):
|
||||||
same type are grouped together, so this function creates
|
same type are grouped together, so this function creates
|
||||||
Groups of all clusters of adjacent Mobjects in the scene
|
Groups of all clusters of adjacent Mobjects in the scene
|
||||||
"""
|
"""
|
||||||
batches = batch_by_property(self.mobjects, lambda m: str(type(m)))
|
batches = batch_by_property(
|
||||||
|
self.mobjects,
|
||||||
|
lambda m: str(type(m)) + str(m.get_uniforms())
|
||||||
|
)
|
||||||
|
|
||||||
for group in self.render_groups:
|
for group in self.render_groups:
|
||||||
group.clear()
|
group.clear()
|
||||||
|
|
|
@ -37,7 +37,7 @@ class ShaderWrapper(object):
|
||||||
vert_data: np.ndarray,
|
vert_data: np.ndarray,
|
||||||
vert_indices: Optional[np.ndarray] = None,
|
vert_indices: Optional[np.ndarray] = None,
|
||||||
shader_folder: Optional[str] = None,
|
shader_folder: Optional[str] = None,
|
||||||
uniforms: Optional[UniformDict] = None, # A dictionary mapping names of uniform variables
|
mobject_uniforms: Optional[UniformDict] = None, # A dictionary mapping names of uniform variables
|
||||||
texture_paths: Optional[dict[str, str]] = None, # A dictionary mapping names to filepaths for textures.
|
texture_paths: Optional[dict[str, str]] = None, # A dictionary mapping names to filepaths for textures.
|
||||||
depth_test: bool = False,
|
depth_test: bool = False,
|
||||||
render_primitive: int = moderngl.TRIANGLE_STRIP,
|
render_primitive: int = moderngl.TRIANGLE_STRIP,
|
||||||
|
@ -47,13 +47,14 @@ class ShaderWrapper(object):
|
||||||
self.vert_indices = (vert_indices or np.zeros(0)).astype(int)
|
self.vert_indices = (vert_indices or np.zeros(0)).astype(int)
|
||||||
self.vert_attributes = vert_data.dtype.names
|
self.vert_attributes = vert_data.dtype.names
|
||||||
self.shader_folder = shader_folder
|
self.shader_folder = shader_folder
|
||||||
self.uniforms: UniformDict = dict()
|
|
||||||
self.depth_test = depth_test
|
self.depth_test = depth_test
|
||||||
self.render_primitive = render_primitive
|
self.render_primitive = render_primitive
|
||||||
|
|
||||||
|
self.program_uniform_mirror: UniformDict = dict()
|
||||||
|
self.bind_to_mobject_uniforms(mobject_uniforms)
|
||||||
|
|
||||||
self.init_program_code()
|
self.init_program_code()
|
||||||
self.init_program()
|
self.init_program()
|
||||||
self.update_program_uniforms(uniforms or dict())
|
|
||||||
if texture_paths is not None:
|
if texture_paths is not None:
|
||||||
self.init_textures(texture_paths)
|
self.init_textures(texture_paths)
|
||||||
self.init_vao()
|
self.init_vao()
|
||||||
|
@ -91,14 +92,17 @@ class ShaderWrapper(object):
|
||||||
self.ibo = None
|
self.ibo = None
|
||||||
self.vao = None
|
self.vao = None
|
||||||
|
|
||||||
|
def bind_to_mobject_uniforms(self, mobject_uniforms: UniformDict):
|
||||||
|
self.mobject_uniforms = mobject_uniforms
|
||||||
|
|
||||||
def __eq__(self, shader_wrapper: ShaderWrapper):
|
def __eq__(self, shader_wrapper: ShaderWrapper):
|
||||||
return all((
|
return all((
|
||||||
np.all(self.vert_data == shader_wrapper.vert_data),
|
np.all(self.vert_data == shader_wrapper.vert_data),
|
||||||
np.all(self.vert_indices == shader_wrapper.vert_indices),
|
np.all(self.vert_indices == shader_wrapper.vert_indices),
|
||||||
self.shader_folder == shader_wrapper.shader_folder,
|
self.shader_folder == shader_wrapper.shader_folder,
|
||||||
all(
|
all(
|
||||||
self.uniforms[key] == shader_wrapper.uniforms[key]
|
self.mobject_uniforms[key] == shader_wrapper.mobject_uniforms[key]
|
||||||
for key in self.uniforms
|
for key in self.mobject_uniforms
|
||||||
),
|
),
|
||||||
self.depth_test == shader_wrapper.depth_test,
|
self.depth_test == shader_wrapper.depth_test,
|
||||||
self.render_primitive == shader_wrapper.render_primitive,
|
self.render_primitive == shader_wrapper.render_primitive,
|
||||||
|
@ -129,7 +133,7 @@ class ShaderWrapper(object):
|
||||||
# A unique id for a shader
|
# A unique id for a shader
|
||||||
return "|".join(map(str, [
|
return "|".join(map(str, [
|
||||||
self.program_id,
|
self.program_id,
|
||||||
self.uniforms,
|
self.mobject_uniforms,
|
||||||
self.depth_test,
|
self.depth_test,
|
||||||
self.render_primitive,
|
self.render_primitive,
|
||||||
]))
|
]))
|
||||||
|
@ -155,9 +159,9 @@ class ShaderWrapper(object):
|
||||||
|
|
||||||
# Changing context
|
# Changing context
|
||||||
def use_clip_plane(self):
|
def use_clip_plane(self):
|
||||||
if "clip_plane" not in self.uniforms:
|
if "clip_plane" not in self.mobject_uniforms:
|
||||||
return False
|
return False
|
||||||
return any(self.uniforms["clip_plane"])
|
return any(self.mobject_uniforms["clip_plane"])
|
||||||
|
|
||||||
def set_ctx_depth_test(self, enable: bool = True) -> None:
|
def set_ctx_depth_test(self, enable: bool = True) -> None:
|
||||||
if enable:
|
if enable:
|
||||||
|
@ -222,18 +226,18 @@ class ShaderWrapper(object):
|
||||||
assert(self.vao is not None)
|
assert(self.vao is not None)
|
||||||
self.vao.render()
|
self.vao.render()
|
||||||
|
|
||||||
def update_program_uniforms(self, uniforms: UniformDict, universal: bool = False):
|
def update_program_uniforms(self, camera_uniforms: UniformDict):
|
||||||
if self.program is None:
|
if self.program is None:
|
||||||
return
|
return
|
||||||
for name, value in uniforms.items():
|
for name, value in (*self.mobject_uniforms.items(), *camera_uniforms.items()):
|
||||||
if name not in self.program:
|
if name not in self.program:
|
||||||
continue
|
continue
|
||||||
if isinstance(value, np.ndarray) and value.ndim > 0:
|
if isinstance(value, np.ndarray) and value.ndim > 0:
|
||||||
value = tuple(value)
|
value = tuple(value)
|
||||||
if universal and self.uniforms.get(name, None) == value:
|
if name in camera_uniforms and self.program_uniform_mirror.get(name, None) == value:
|
||||||
continue
|
continue
|
||||||
self.program[name].value = value
|
self.program[name].value = value
|
||||||
self.uniforms[name] = value
|
self.program_uniform_mirror[name] = value
|
||||||
|
|
||||||
def get_vertex_buffer_object(self, refresh: bool = True):
|
def get_vertex_buffer_object(self, refresh: bool = True):
|
||||||
if refresh:
|
if refresh:
|
||||||
|
|
Loading…
Add table
Reference in a new issue