diff --git a/manimlib/mobject/mobject.py b/manimlib/mobject/mobject.py index 8e555c40..4017b6cb 100644 --- a/manimlib/mobject/mobject.py +++ b/manimlib/mobject/mobject.py @@ -103,9 +103,9 @@ class Mobject(object): self.saved_state = None self.target = None self.bounding_box: Vect3Array = np.zeros((3, 3)) + self.shader_wrapper: Optional[ShaderWrapper] = None self._is_animating: bool = False self._needs_new_bounding_box: bool = True - self._shaders_initialized: bool = False self._data_has_changed: bool = True self.shader_code_replacements: dict[str, str] = dict() @@ -652,13 +652,10 @@ class Mobject(object): return self def deepcopy(self) -> Self: + parents = self.parents self.parents = [] - result.target = None - result.saved_state = None - for submob in self.get_family(): - submob._shaders_initialized = False - submob._data_has_changed = True result = copy.deepcopy(self) + self.parents = parents return result def copy(self, deep: bool = False) -> Self: @@ -691,7 +688,7 @@ class Mobject(object): # won't have changed, just directly match. result.updaters = list(self.updaters) result._data_has_changed = True - result._shaders_initialized = False + result.shader_wrapper = None family = self.get_family() for attr, value in self.__dict__.items(): @@ -1947,9 +1944,7 @@ class Mobject(object): def replace_shader_code(self, old: str, new: str) -> Self: for mob in self.get_family(): mob.shader_code_replacements[old] = new - mob._shaders_initialized = False - for mob in self.get_ancestors(): - mob._shaders_initialized = False + mob.shader_wrapper = None return self def set_color_by_code(self, glsl_code: str) -> Self: @@ -1993,8 +1988,7 @@ class Mobject(object): # For shader data - def init_shader_data(self, ctx: Context): - self.shader_indices = None + def init_shader_wrapper(self, ctx: Context): self.shader_wrapper = ShaderWrapper( ctx=ctx, vert_data=self.data, @@ -2007,15 +2001,14 @@ class Mobject(object): def refresh_shader_wrapper_id(self): for submob in self.get_family(): - if submob._shaders_initialized: + if submob.shader_wrapper is not None: submob.shader_wrapper.depth_test = submob.depth_test submob.shader_wrapper.refresh_id() return self def get_shader_wrapper(self, ctx: Context) -> ShaderWrapper: - if not self._shaders_initialized: - self.init_shader_data(ctx) - self._shaders_initialized = True + if self.shader_wrapper is None: + self.init_shader_wrapper(ctx) return self.shader_wrapper def get_shader_wrapper_list(self, ctx: Context) -> list[ShaderWrapper]: @@ -2041,7 +2034,7 @@ class Mobject(object): return self.uniforms def get_shader_vert_indices(self) -> Optional[np.ndarray]: - return self.shader_indices + return None def render(self, ctx: Context, camera_uniforms: dict): if self._data_has_changed: diff --git a/manimlib/mobject/types/vectorized_mobject.py b/manimlib/mobject/types/vectorized_mobject.py index 2afff561..d465ac43 100644 --- a/manimlib/mobject/types/vectorized_mobject.py +++ b/manimlib/mobject/types/vectorized_mobject.py @@ -1286,8 +1286,7 @@ class VMobject(Mobject): # For shaders - def init_shader_data(self, ctx: Context): - self.shader_indices = None + def init_shader_wrapper(self, ctx: Context): self.shader_wrapper = VShaderWrapper( ctx=ctx, vert_data=self.data, @@ -1299,7 +1298,7 @@ class VMobject(Mobject): def refresh_shader_wrapper_id(self): for submob in self.get_family(): - if submob._shaders_initialized: + if submob.shader_wrapper is not None: submob.shader_wrapper.stroke_behind = submob.stroke_behind super().refresh_shader_wrapper_id() return self