Move rendering more fully away from Camera to Mobject and ShaderWrapper

This commit is contained in:
Grant Sanderson 2023-01-25 14:13:56 -08:00
parent 2c737ed540
commit 424707d035
4 changed files with 37 additions and 77 deletions

View file

@ -65,15 +65,10 @@ class Camera(object):
self.background_rgba: list[float] = list(color_to_rgba( self.background_rgba: list[float] = list(color_to_rgba(
background_color, background_opacity background_color, background_opacity
)) ))
self.perspective_uniforms = dict() self.uniforms = dict()
self.init_frame(**frame_config) self.init_frame(**frame_config)
self.init_context(window) self.init_context(window)
self.init_light_source() self.init_light_source()
self.refresh_perspective_uniforms()
# A cached map from mobjects to their associated list of render groups
# so that these render groups are not regenerated unnecessarily for static
# mobjects
self.mob_to_render_groups = {}
def init_frame(self, **config) -> None: def init_frame(self, **config) -> None:
self.frame = CameraFrame(**config) self.frame = CameraFrame(**config)
@ -212,60 +207,17 @@ class Camera(object):
# Rendering # Rendering
def capture(self, *mobjects: Mobject) -> None: def capture(self, *mobjects: Mobject) -> None:
self.refresh_perspective_uniforms() self.refresh_uniforms()
for mobject in mobjects: for mobject in mobjects:
for render_group in self.get_render_group_list(mobject): mobject.render(self.ctx, self.uniforms)
self.render(render_group)
def render(self, render_group: dict[str, Any]) -> None: def refresh_uniforms(self) -> None:
shader_wrapper = render_group["shader_wrapper"]
shader_wrapper.render(self.perspective_uniforms)
if render_group["single_use"]:
self.release_render_group(render_group)
def get_render_group_list(self, mobject: Mobject) -> Iterable[dict[str, Any]]:
if mobject.is_changing():
return self.generate_render_group_list(mobject)
# Otherwise, cache result for later use
key = id(mobject)
if key not in self.mob_to_render_groups:
self.mob_to_render_groups[key] = list(self.generate_render_group_list(mobject))
return self.mob_to_render_groups[key]
def generate_render_group_list(self, mobject: Mobject) -> Iterable[dict[str, Any]]:
return (
self.get_render_group(sw, single_use=mobject.is_changing())
for sw in mobject.get_shader_wrapper_list(self.ctx)
)
def get_render_group(
self,
shader_wrapper: ShaderWrapper,
single_use: bool = True
) -> dict[str, Any]:
shader_wrapper.get_vao()
return {
"shader_wrapper": shader_wrapper,
"single_use": single_use,
}
def release_render_group(self, render_group: dict[str, Any]) -> None:
render_group["shader_wrapper"].release()
def refresh_static_mobjects(self) -> None:
for render_group in it.chain(*self.mob_to_render_groups.values()):
self.release_render_group(render_group)
self.mob_to_render_groups = {}
def refresh_perspective_uniforms(self) -> None:
frame = self.frame frame = self.frame
view_matrix = frame.get_view_matrix() view_matrix = frame.get_view_matrix()
light_pos = self.light_source.get_location() light_pos = self.light_source.get_location()
cam_pos = self.frame.get_implied_camera_location() cam_pos = self.frame.get_implied_camera_location()
self.perspective_uniforms.update( self.uniforms.update(
frame_shape=frame.get_shape(), frame_shape=frame.get_shape(),
pixel_size=self.get_pixel_size(), pixel_size=self.get_pixel_size(),
view=tuple(view_matrix.T.flatten()), view=tuple(view_matrix.T.flatten()),

View file

@ -103,6 +103,7 @@ class Mobject(object):
self.target = None self.target = None
self.bounding_box: Vect3Array = np.zeros((3, 3)) self.bounding_box: Vect3Array = np.zeros((3, 3))
self._shaders_initialized: bool = False self._shaders_initialized: bool = False
self._data_has_changed: bool = True
self.init_data() self.init_data()
self._data_defaults = np.ones(1, dtype=self.data.dtype) self._data_defaults = np.ones(1, dtype=self.data.dtype)
@ -1892,11 +1893,16 @@ class Mobject(object):
return self.shader_indices return self.shader_indices
def render(self, ctx: Context, camera_uniforms: dict): def render(self, ctx: Context, camera_uniforms: dict):
if self.data_has_changed: if self._data_has_changed or self.is_changing():
self.shader_wrappers = self.get_shader_wrapper_list(ctx) self.shader_wrappers = self.get_shader_wrapper_list(ctx)
for shader_wrapper in self.shader_wrappers:
shader_wrapper.release()
shader_wrapper.get_vao()
self._data_has_changed = False
for shader_wrapper in self.shader_wrappers: for shader_wrapper in self.shader_wrappers:
shader_wrapper.update_uniforms(camera_uniforms) shader_wrapper.uniforms.update(self.get_uniforms())
shader_wrapper.update_uniforms(self.get_uniforms) shader_wrapper.uniforms.update(camera_uniforms)
shader_wrapper.pre_render()
shader_wrapper.render() shader_wrapper.render()
# Event Handlers # Event Handlers

View file

@ -575,7 +575,8 @@ class Scene(object):
self.num_plays += 1 self.num_plays += 1
def refresh_static_mobjects(self) -> None: def refresh_static_mobjects(self) -> None:
self.camera.refresh_static_mobjects() for mobject in self.mobjects:
mobject._data_has_changed = True
def begin_animations(self, animations: Iterable[Animation]) -> None: def begin_animations(self, animations: Iterable[Animation]) -> None:
for animation in animations: for animation in animations:

View file

@ -47,7 +47,7 @@ class ShaderWrapper(object):
self.vert_indices = (vert_indices or np.zeros(0)).astype(int) self.vert_indices = (vert_indices or np.zeros(0)).astype(int)
self.vert_attributes = vert_data.dtype.names self.vert_attributes = vert_data.dtype.names
self.shader_folder = shader_folder self.shader_folder = shader_folder
self.uniforms = uniforms or dict() self.uniforms = dict(uniforms or {})
self.depth_test = depth_test self.depth_test = depth_test
self.render_primitive = render_primitive self.render_primitive = render_primitive
@ -168,16 +168,7 @@ class ShaderWrapper(object):
if enable: if enable:
gl.glEnable(gl.GL_CLIP_DISTANCE0) gl.glEnable(gl.GL_CLIP_DISTANCE0)
# Adding data
# Related to data and rendering
def render(self, camera_uniforms: dict):
self.update_program_uniforms(camera_uniforms)
self.set_ctx_depth_test(self.depth_test)
self.set_ctx_clip_plane(self.use_clip_plane())
# TODO, generate on the fly?
assert(self.vao is not None)
self.vao.render(self.render_primitive)
def combine_with(self, *shader_wrappers: ShaderWrapper) -> ShaderWrapper: def combine_with(self, *shader_wrappers: ShaderWrapper) -> ShaderWrapper:
if len(shader_wrappers) > 0: if len(shader_wrappers) > 0:
@ -221,10 +212,21 @@ class ShaderWrapper(object):
n_points += len(data) n_points += len(data)
return self return self
def update_program_uniforms(self, camera_uniforms: dict): # Related to data and rendering
def pre_render(self):
self.set_ctx_depth_test(self.depth_test)
self.set_ctx_clip_plane(self.use_clip_plane())
self.update_program_uniforms()
def render(self):
# TODO, generate on the fly?
assert(self.vao is not None)
self.vao.render()
def update_program_uniforms(self):
if self.program is None: if self.program is None:
return return
for name, value in (*camera_uniforms.items(), *self.uniforms.items()): for name, value in self.uniforms.items():
if name in self.program: if name in self.program:
if isinstance(value, np.ndarray) and value.ndim > 0: if isinstance(value, np.ndarray) and value.ndim > 0:
value = tuple(value) value = tuple(value)
@ -249,13 +251,17 @@ class ShaderWrapper(object):
program=self.program, program=self.program,
content=[(vbo, self.vert_format, *self.vert_attributes)], content=[(vbo, self.vert_format, *self.vert_attributes)],
index_buffer=ibo, index_buffer=ibo,
mode=self.render_primitive,
) )
return self.vao return self.vao
def release(self): def release(self):
for obj in (self.vbo, self.ibo, self.vao): for obj in (self.vbo, self.ibo, self.vao):
if obj is not None: if obj is not None:
obj.release() try:
obj.release()
except AttributeError:
pass
self.vbo = None self.vbo = None
self.ibo = None self.ibo = None
self.vao = None self.vao = None
@ -319,12 +325,7 @@ class FillShaderWrapper(ShaderWrapper):
'texcoord', 'texcoord',
) )
def render(self, camera_uniforms: dict): def render(self):
# TODO, these are copied...
self.update_program_uniforms(camera_uniforms)
self.set_ctx_depth_test(self.depth_test)
self.set_ctx_clip_plane(self.use_clip_plane())
#
vao = self.vao vao = self.vao
assert(vao is not None) assert(vao is not None)
winding = (len(self.vert_indices) == 0) winding = (len(self.vert_indices) == 0)