Unify get_shader_wrapper_list, and and better subdivide render groups by ShaderWrapper ids

This commit is contained in:
Grant Sanderson 2024-08-20 08:53:51 -05:00
parent 08e33faab8
commit f12b143d16
4 changed files with 50 additions and 65 deletions

View file

@ -1944,8 +1944,9 @@ class Mobject(object):
@affects_data
def replace_shader_code(self, old: str, new: str) -> Self:
self.shader_code_replacements[old] = new
self._shaders_initialized = False
for mob in self.get_family():
mob.shader_code_replacements[old] = new
mob._shaders_initialized = False
for mob in self.get_ancestors():
mob._shaders_initialized = False
return self
@ -2000,48 +2001,45 @@ class Mobject(object):
texture_paths=self.texture_paths,
depth_test=self.depth_test,
render_primitive=self.render_primitive,
code_replacements=self.shader_code_replacements,
)
def refresh_shader_wrapper_id(self):
if self._shaders_initialized:
self.shader_wrapper.refresh_id()
for submob in self.get_family():
if submob._shaders_initialized:
submob.shader_wrapper.depth_test = submob.depth_test
submob.shader_wrapper.refresh_id()
return self
def get_shader_wrapper(self, ctx: Context) -> ShaderWrapper:
if not self._shaders_initialized:
self.init_shader_data(ctx)
self._shaders_initialized = True
self.shader_wrapper.bind_to_mobject_uniforms(self.get_uniforms())
self.shader_wrapper.depth_test = self.depth_test
for old, new in self.shader_code_replacements.items():
self.shader_wrapper.replace_code(old, new)
return self.shader_wrapper
def get_shader_wrapper_list(self, ctx: Context) -> list[ShaderWrapper]:
family = self.family_members_with_points()
for submob in family:
submob.get_shader_wrapper(ctx)
batches = batch_by_property(family, lambda submob: submob.shader_wrapper.get_id())
batches = batch_by_property(family, lambda sm: sm.get_shader_wrapper(ctx).get_id())
result = []
for submobs, sid in batches:
shader_wrapper = submobs[0].shader_wrapper
data_list = [sm.get_shader_data() for sm in submobs]
indices_list = [sm.get_shader_vert_indices() for sm in submobs]
if indices_list[0] is None:
indices_list = None
shader_wrapper.read_in(data_list, indices_list)
data_list = list(it.chain(*(sm.get_shader_data() for sm in submobs)))
shader_wrapper.read_in(data_list, indices_list=None)
result.append(shader_wrapper)
return result
def get_shader_data(self):
return self.data
def get_shader_data(self) -> Iterable[np.ndarray]:
indices = self.get_shader_vert_indices()
if indices is not None:
return [self.data[indices]]
else:
return [self.data]
def get_uniforms(self):
return self.uniforms
def get_shader_vert_indices(self):
def get_shader_vert_indices(self) -> Optional[np.ndarray]:
return self.shader_indices
def render(self, ctx: Context, camera_uniforms: dict):

View file

@ -191,6 +191,7 @@ class VMobject(Mobject):
if background is not None:
for mob in self.get_family(recurse):
mob.stroke_behind = background
mob.refresh_shader_wrapper_id()
if flat is not None:
self.set_flat_stroke(flat)
@ -1285,44 +1286,27 @@ class VMobject(Mobject):
# For shaders
def init_shader_data(self, ctx: Context):
self.shader_indices = np.zeros(0)
self.shader_indices = None
self.shader_wrapper = VShaderWrapper(
ctx=ctx,
vert_data=self.data,
mobject_uniforms=self.uniforms,
code_replacements=self.shader_code_replacements,
stroke_behind=self.stroke_behind,
depth_test=self.depth_test
)
def get_shader_vert_indices(self):
return None
def refresh_shader_wrapper_id(self):
for submob in self.get_family():
if submob._shaders_initialized:
submob.shader_wrapper.stroke_behind = submob.stroke_behind
super().refresh_shader_wrapper_id()
return self
def get_shader_data(self):
# This should only come up when VMobjects appear together in a group
return np.hstack([self.data, self.data[-1:]])
def get_shader_wrapper_list(self, ctx: Context) -> list[ShaderWrapper]:
if not self._shaders_initialized:
self.init_shader_data(ctx)
self._shaders_initialized = True
family = self.family_members_with_points()
if not family:
return []
stroke_behind = False
for submob in family:
# Maybe do this on set points instead? Or on noting changed data?
submob.data["base_normal"][0::2] = submob.data["point"][0]
if submob.stroke_behind:
stroke_behind = True
self.shader_wrapper.read_in(
list(it.chain(*([sm.data, sm.data[-1:]] for sm in family)))
)
rep = family[0] # Representative family member
self.shader_wrapper.bind_to_mobject_uniforms(rep.get_uniforms())
self.shader_wrapper.depth_test = rep.depth_test
self.shader_wrapper.stroke_behind = stroke_behind
return [self.shader_wrapper]
def get_shader_data(self) -> Iterable[np.ndarray]:
# Do we want this elsewhere? Say whenever points are refreshed or something?
self.data["base_normal"][0::2] = self.data["point"][0]
return [self.data, self.data[-1:]]
class VGroup(Group, VMobject, Generic[SubVmobjectType]):

View file

@ -388,7 +388,7 @@ class Scene(object):
"""
batches = batch_by_property(
self.mobjects,
lambda m: str(type(m)) + str(m.get_uniforms())
lambda m: m.get_shader_wrapper(self.camera.ctx).get_id()
)
for group in self.render_groups:

View file

@ -41,6 +41,7 @@ class ShaderWrapper(object):
texture_paths: Optional[dict[str, str]] = None, # A dictionary mapping names to filepaths for textures.
depth_test: bool = False,
render_primitive: int = moderngl.TRIANGLE_STRIP,
code_replacements: dict[str, str] = dict(),
):
self.ctx = ctx
self.vert_data = vert_data
@ -48,13 +49,15 @@ class ShaderWrapper(object):
self.shader_folder = shader_folder
self.depth_test = depth_test
self.render_primitive = render_primitive
self.texture_names_to_ids = dict()
self.program_uniform_mirror: UniformDict = dict()
self.bind_to_mobject_uniforms(mobject_uniforms or dict())
self.init_program_code()
for old, new in code_replacements.items():
self.replace_code(old, new)
self.init_program()
self.texture_names_to_ids = dict()
if texture_paths is not None:
self.init_textures(texture_paths)
self.init_vertex_objects()
@ -98,19 +101,14 @@ class ShaderWrapper(object):
def get_id(self) -> str:
return self.id
def create_id(self) -> str:
# A unique id for a shader
program_id = hash("".join(map(str, self.program_code.values())))
return "|".join(map(str, [
program_id,
def refresh_id(self) -> None:
self.id = hash("".join(map(str, [
"".join(map(str, self.program_code.values())),
self.mobject_uniforms,
self.depth_test,
self.render_primitive,
self.texture_names_to_ids,
]))
def refresh_id(self) -> None:
self.id = self.create_id()
])))
def replace_code(self, old: str, new: str) -> None:
code_map = self.program_code
@ -213,10 +211,12 @@ class VShaderWrapper(ShaderWrapper):
mobject_uniforms: Optional[UniformDict] = None, # A dictionary mapping names of uniform variables
texture_paths: Optional[dict[str, str]] = None, # A dictionary mapping names to filepaths for textures.
depth_test: bool = False,
# render_primitive: int = moderngl.TRIANGLES,
render_primitive: int = moderngl.TRIANGLE_STRIP,
code_replacements: dict[str, str] = dict(),
stroke_behind: bool = False,
):
self.stroke_behind = stroke_behind
self.fill_canvas = get_fill_canvas(ctx)
super().__init__(
ctx=ctx,
vert_data=vert_data,
@ -225,9 +225,8 @@ class VShaderWrapper(ShaderWrapper):
texture_paths=texture_paths,
depth_test=depth_test,
render_primitive=render_primitive,
code_replacements=code_replacements,
)
self.stroke_behind = stroke_behind
self.fill_canvas = get_fill_canvas(self.ctx)
def init_program_code(self) -> None:
self.program_code = {
@ -312,6 +311,10 @@ class VShaderWrapper(ShaderWrapper):
def set_backstroke(self, value: bool = True):
self.stroke_behind = value
def refresh_id(self):
super().refresh_id()
self.id = hash(str(self.id) + str(self.stroke_behind))
# TODO, motidify read in to handle triangulation case for non-winding fill?
# Rendering