Instead of tracking _shaders_initialized, just check if self.shader_wrapper is None

This commit is contained in:
Grant Sanderson 2024-08-20 10:48:43 -05:00
parent 0ac9ee1fbf
commit e0191d81d9
2 changed files with 12 additions and 20 deletions

View file

@ -103,9 +103,9 @@ class Mobject(object):
self.saved_state = None
self.target = None
self.bounding_box: Vect3Array = np.zeros((3, 3))
self.shader_wrapper: Optional[ShaderWrapper] = None
self._is_animating: bool = False
self._needs_new_bounding_box: bool = True
self._shaders_initialized: bool = False
self._data_has_changed: bool = True
self.shader_code_replacements: dict[str, str] = dict()
@ -652,13 +652,10 @@ class Mobject(object):
return self
def deepcopy(self) -> Self:
parents = self.parents
self.parents = []
result.target = None
result.saved_state = None
for submob in self.get_family():
submob._shaders_initialized = False
submob._data_has_changed = True
result = copy.deepcopy(self)
self.parents = parents
return result
def copy(self, deep: bool = False) -> Self:
@ -691,7 +688,7 @@ class Mobject(object):
# won't have changed, just directly match.
result.updaters = list(self.updaters)
result._data_has_changed = True
result._shaders_initialized = False
result.shader_wrapper = None
family = self.get_family()
for attr, value in self.__dict__.items():
@ -1947,9 +1944,7 @@ class Mobject(object):
def replace_shader_code(self, old: str, new: str) -> Self:
for mob in self.get_family():
mob.shader_code_replacements[old] = new
mob._shaders_initialized = False
for mob in self.get_ancestors():
mob._shaders_initialized = False
mob.shader_wrapper = None
return self
def set_color_by_code(self, glsl_code: str) -> Self:
@ -1993,8 +1988,7 @@ class Mobject(object):
# For shader data
def init_shader_data(self, ctx: Context):
self.shader_indices = None
def init_shader_wrapper(self, ctx: Context):
self.shader_wrapper = ShaderWrapper(
ctx=ctx,
vert_data=self.data,
@ -2007,15 +2001,14 @@ class Mobject(object):
def refresh_shader_wrapper_id(self):
for submob in self.get_family():
if submob._shaders_initialized:
if submob.shader_wrapper is not None:
submob.shader_wrapper.depth_test = submob.depth_test
submob.shader_wrapper.refresh_id()
return self
def get_shader_wrapper(self, ctx: Context) -> ShaderWrapper:
if not self._shaders_initialized:
self.init_shader_data(ctx)
self._shaders_initialized = True
if self.shader_wrapper is None:
self.init_shader_wrapper(ctx)
return self.shader_wrapper
def get_shader_wrapper_list(self, ctx: Context) -> list[ShaderWrapper]:
@ -2041,7 +2034,7 @@ class Mobject(object):
return self.uniforms
def get_shader_vert_indices(self) -> Optional[np.ndarray]:
return self.shader_indices
return None
def render(self, ctx: Context, camera_uniforms: dict):
if self._data_has_changed:

View file

@ -1286,8 +1286,7 @@ class VMobject(Mobject):
# For shaders
def init_shader_data(self, ctx: Context):
self.shader_indices = None
def init_shader_wrapper(self, ctx: Context):
self.shader_wrapper = VShaderWrapper(
ctx=ctx,
vert_data=self.data,
@ -1299,7 +1298,7 @@ class VMobject(Mobject):
def refresh_shader_wrapper_id(self):
for submob in self.get_family():
if submob._shaders_initialized:
if submob.shader_wrapper is not None:
submob.shader_wrapper.stroke_behind = submob.stroke_behind
super().refresh_shader_wrapper_id()
return self