diff --git a/manimlib/mobject/mobject.py b/manimlib/mobject/mobject.py index cf35b3b3..bad88b34 100644 --- a/manimlib/mobject/mobject.py +++ b/manimlib/mobject/mobject.py @@ -105,6 +105,7 @@ class Mobject(object): self.bounding_box: Vect3Array = np.zeros((3, 3)) self._shaders_initialized: bool = False self._data_has_changed: bool = True + self.shader_code_replacements: dict[str, str] = dict() self.init_data() self._data_defaults = np.ones(1, dtype=self.data.dtype) @@ -1895,12 +1896,12 @@ class Mobject(object): # Shader code manipulation + @affects_data def replace_shader_code(self, old: str, new: str) -> Self: - # TODO, will this work with VMobject structure, given - # that it does not simpler return shader_wrappers of - # family? - for wrapper in self.get_shader_wrapper_list(): - wrapper.replace_code(old, new) + self.shader_code_replacements[old] = new + self._shaders_initialized = False + for mob in self.get_ancestors(): + mob._shaders_initialized = False return self def set_color_by_code(self, glsl_code: str) -> Self: @@ -1969,6 +1970,8 @@ class Mobject(object): self.shader_wrapper.vert_indices = self.get_shader_vert_indices() self.shader_wrapper.bind_to_mobject_uniforms(self.get_uniforms()) self.shader_wrapper.depth_test = self.depth_test + for old, new in self.shader_code_replacements.items(): + self.shader_wrapper.replace_code(old, new) return self.shader_wrapper def get_shader_wrapper_list(self, ctx: Context) -> list[ShaderWrapper]: diff --git a/manimlib/mobject/types/vectorized_mobject.py b/manimlib/mobject/types/vectorized_mobject.py index ce4eae34..9b27ddb9 100644 --- a/manimlib/mobject/types/vectorized_mobject.py +++ b/manimlib/mobject/types/vectorized_mobject.py @@ -1292,6 +1292,10 @@ class VMobject(Mobject): self.fill_shader_wrapper, self.stroke_shader_wrapper, ] + for sw in self.shader_wrappers: + rep = self.family_members_with_points()[0] + for old, new in rep.shader_code_replacements.items(): + sw.replace_code(old, new) def refresh_shader_wrapper_id(self) -> Self: if not self._shaders_initialized: @@ -1355,8 +1359,9 @@ class VMobject(Mobject): self.stroke_shader_wrapper.read_in(stroke_datas), ] for sw in shader_wrappers: - sw.bind_to_mobject_uniforms(family[0].get_uniforms()) - sw.depth_test = family[0].depth_test + rep = family[0] # Representative family member + sw.bind_to_mobject_uniforms(rep.get_uniforms()) + sw.depth_test = rep.depth_test return [sw for sw in shader_wrappers if len(sw.vert_data) > 0]