Only initialize ShaderWrappers as needed

This commit is contained in:
Grant Sanderson 2023-01-25 10:37:12 -08:00
parent 8c1e5f3b42
commit 80729c0cb8
2 changed files with 11 additions and 3 deletions

View file

@ -101,6 +101,7 @@ class Mobject(object):
self.saved_state = None self.saved_state = None
self.target = None self.target = None
self.bounding_box: Vect3Array = np.zeros((3, 3)) self.bounding_box: Vect3Array = np.zeros((3, 3))
self._shaders_initialized: bool = False
self.init_data() self.init_data()
self._data_defaults = np.ones(1, dtype=self.data.dtype) self._data_defaults = np.ones(1, dtype=self.data.dtype)
@ -109,7 +110,6 @@ class Mobject(object):
self.init_event_listners() self.init_event_listners()
self.init_points() self.init_points()
self.init_colors() self.init_colors()
self.init_shader_data()
if self.depth_test: if self.depth_test:
self.apply_depth_test() self.apply_depth_test()
@ -1843,7 +1843,6 @@ class Mobject(object):
# For shader data # For shader data
def init_shader_data(self): def init_shader_data(self):
# TODO, only call this when needed?
self.shader_indices = np.zeros(0) self.shader_indices = np.zeros(0)
self.shader_wrapper = ShaderWrapper( self.shader_wrapper = ShaderWrapper(
vert_data=self.data, vert_data=self.data,
@ -1854,10 +1853,15 @@ class Mobject(object):
) )
def refresh_shader_wrapper_id(self): def refresh_shader_wrapper_id(self):
self.shader_wrapper.refresh_id() if self._shaders_initialized:
self.shader_wrapper.refresh_id()
return self return self
def get_shader_wrapper(self) -> ShaderWrapper: def get_shader_wrapper(self) -> ShaderWrapper:
if not self._shaders_initialized:
self.init_shader_data()
self._shaders_initialized = True
self.shader_wrapper.vert_data = self.get_shader_data() self.shader_wrapper.vert_data = self.get_shader_data()
self.shader_wrapper.vert_indices = self.get_shader_vert_indices() self.shader_wrapper.vert_indices = self.get_shader_vert_indices()
self.shader_wrapper.uniforms = self.get_uniforms() self.shader_wrapper.uniforms = self.get_uniforms()

View file

@ -1196,6 +1196,10 @@ class VMobject(Mobject):
return self return self
def get_shader_wrapper_list(self) -> list[ShaderWrapper]: def get_shader_wrapper_list(self) -> list[ShaderWrapper]:
if not self._shaders_initialized:
self.init_shader_data()
self._shaders_initialized = True
family = self.family_members_with_points() family = self.family_members_with_points()
if not family: if not family:
return [] return []