FIx Mobject.replace_shader_code

This commit is contained in:
Grant Sanderson 2023-02-02 20:47:55 -08:00
parent d10745a379
commit c4777015fc
2 changed files with 15 additions and 7 deletions

View file

@ -105,6 +105,7 @@ class Mobject(object):
self.bounding_box: Vect3Array = np.zeros((3, 3))
self._shaders_initialized: bool = False
self._data_has_changed: bool = True
self.shader_code_replacements: dict[str, str] = dict()
self.init_data()
self._data_defaults = np.ones(1, dtype=self.data.dtype)
@ -1895,12 +1896,12 @@ class Mobject(object):
# Shader code manipulation
@affects_data
def replace_shader_code(self, old: str, new: str) -> Self:
# TODO, will this work with VMobject structure, given
# that it does not simpler return shader_wrappers of
# family?
for wrapper in self.get_shader_wrapper_list():
wrapper.replace_code(old, new)
self.shader_code_replacements[old] = new
self._shaders_initialized = False
for mob in self.get_ancestors():
mob._shaders_initialized = False
return self
def set_color_by_code(self, glsl_code: str) -> Self:
@ -1969,6 +1970,8 @@ class Mobject(object):
self.shader_wrapper.vert_indices = self.get_shader_vert_indices()
self.shader_wrapper.bind_to_mobject_uniforms(self.get_uniforms())
self.shader_wrapper.depth_test = self.depth_test
for old, new in self.shader_code_replacements.items():
self.shader_wrapper.replace_code(old, new)
return self.shader_wrapper
def get_shader_wrapper_list(self, ctx: Context) -> list[ShaderWrapper]:

View file

@ -1292,6 +1292,10 @@ class VMobject(Mobject):
self.fill_shader_wrapper,
self.stroke_shader_wrapper,
]
for sw in self.shader_wrappers:
rep = self.family_members_with_points()[0]
for old, new in rep.shader_code_replacements.items():
sw.replace_code(old, new)
def refresh_shader_wrapper_id(self) -> Self:
if not self._shaders_initialized:
@ -1355,8 +1359,9 @@ class VMobject(Mobject):
self.stroke_shader_wrapper.read_in(stroke_datas),
]
for sw in shader_wrappers:
sw.bind_to_mobject_uniforms(family[0].get_uniforms())
sw.depth_test = family[0].depth_test
rep = family[0] # Representative family member
sw.bind_to_mobject_uniforms(rep.get_uniforms())
sw.depth_test = rep.depth_test
return [sw for sw in shader_wrappers if len(sw.vert_data) > 0]