diff --git a/manimlib/mobject/mobject.py b/manimlib/mobject/mobject.py index 7596dd5c..5c6d22ab 100644 --- a/manimlib/mobject/mobject.py +++ b/manimlib/mobject/mobject.py @@ -1944,8 +1944,9 @@ class Mobject(object): @affects_data def replace_shader_code(self, old: str, new: str) -> Self: - self.shader_code_replacements[old] = new - self._shaders_initialized = False + for mob in self.get_family(): + mob.shader_code_replacements[old] = new + mob._shaders_initialized = False for mob in self.get_ancestors(): mob._shaders_initialized = False return self @@ -2000,48 +2001,45 @@ class Mobject(object): texture_paths=self.texture_paths, depth_test=self.depth_test, render_primitive=self.render_primitive, + code_replacements=self.shader_code_replacements, ) def refresh_shader_wrapper_id(self): - if self._shaders_initialized: - self.shader_wrapper.refresh_id() + for submob in self.get_family(): + if submob._shaders_initialized: + submob.shader_wrapper.depth_test = submob.depth_test + submob.shader_wrapper.refresh_id() return self def get_shader_wrapper(self, ctx: Context) -> ShaderWrapper: if not self._shaders_initialized: self.init_shader_data(ctx) self._shaders_initialized = True - - self.shader_wrapper.bind_to_mobject_uniforms(self.get_uniforms()) - self.shader_wrapper.depth_test = self.depth_test - for old, new in self.shader_code_replacements.items(): - self.shader_wrapper.replace_code(old, new) return self.shader_wrapper def get_shader_wrapper_list(self, ctx: Context) -> list[ShaderWrapper]: family = self.family_members_with_points() - for submob in family: - submob.get_shader_wrapper(ctx) - batches = batch_by_property(family, lambda submob: submob.shader_wrapper.get_id()) + batches = batch_by_property(family, lambda sm: sm.get_shader_wrapper(ctx).get_id()) result = [] for submobs, sid in batches: shader_wrapper = submobs[0].shader_wrapper - data_list = [sm.get_shader_data() for sm in submobs] - indices_list = [sm.get_shader_vert_indices() for sm in submobs] - if indices_list[0] is None: - indices_list = None - shader_wrapper.read_in(data_list, indices_list) + data_list = list(it.chain(*(sm.get_shader_data() for sm in submobs))) + shader_wrapper.read_in(data_list, indices_list=None) result.append(shader_wrapper) return result - def get_shader_data(self): - return self.data + def get_shader_data(self) -> Iterable[np.ndarray]: + indices = self.get_shader_vert_indices() + if indices is not None: + return [self.data[indices]] + else: + return [self.data] def get_uniforms(self): return self.uniforms - def get_shader_vert_indices(self): + def get_shader_vert_indices(self) -> Optional[np.ndarray]: return self.shader_indices def render(self, ctx: Context, camera_uniforms: dict): diff --git a/manimlib/mobject/types/vectorized_mobject.py b/manimlib/mobject/types/vectorized_mobject.py index d8fd2f08..82eaecd3 100644 --- a/manimlib/mobject/types/vectorized_mobject.py +++ b/manimlib/mobject/types/vectorized_mobject.py @@ -191,6 +191,7 @@ class VMobject(Mobject): if background is not None: for mob in self.get_family(recurse): mob.stroke_behind = background + mob.refresh_shader_wrapper_id() if flat is not None: self.set_flat_stroke(flat) @@ -1285,44 +1286,27 @@ class VMobject(Mobject): # For shaders def init_shader_data(self, ctx: Context): - self.shader_indices = np.zeros(0) + self.shader_indices = None self.shader_wrapper = VShaderWrapper( ctx=ctx, vert_data=self.data, mobject_uniforms=self.uniforms, + code_replacements=self.shader_code_replacements, + stroke_behind=self.stroke_behind, + depth_test=self.depth_test ) - def get_shader_vert_indices(self): - return None + def refresh_shader_wrapper_id(self): + for submob in self.get_family(): + if submob._shaders_initialized: + submob.shader_wrapper.stroke_behind = submob.stroke_behind + super().refresh_shader_wrapper_id() + return self - def get_shader_data(self): - # This should only come up when VMobjects appear together in a group - return np.hstack([self.data, self.data[-1:]]) - - def get_shader_wrapper_list(self, ctx: Context) -> list[ShaderWrapper]: - if not self._shaders_initialized: - self.init_shader_data(ctx) - self._shaders_initialized = True - - family = self.family_members_with_points() - if not family: - return [] - - stroke_behind = False - for submob in family: - # Maybe do this on set points instead? Or on noting changed data? - submob.data["base_normal"][0::2] = submob.data["point"][0] - if submob.stroke_behind: - stroke_behind = True - - self.shader_wrapper.read_in( - list(it.chain(*([sm.data, sm.data[-1:]] for sm in family))) - ) - rep = family[0] # Representative family member - self.shader_wrapper.bind_to_mobject_uniforms(rep.get_uniforms()) - self.shader_wrapper.depth_test = rep.depth_test - self.shader_wrapper.stroke_behind = stroke_behind - return [self.shader_wrapper] + def get_shader_data(self) -> Iterable[np.ndarray]: + # Do we want this elsewhere? Say whenever points are refreshed or something? + self.data["base_normal"][0::2] = self.data["point"][0] + return [self.data, self.data[-1:]] class VGroup(Group, VMobject, Generic[SubVmobjectType]): diff --git a/manimlib/scene/scene.py b/manimlib/scene/scene.py index c46ae76d..c2a05bb7 100644 --- a/manimlib/scene/scene.py +++ b/manimlib/scene/scene.py @@ -388,7 +388,7 @@ class Scene(object): """ batches = batch_by_property( self.mobjects, - lambda m: str(type(m)) + str(m.get_uniforms()) + lambda m: m.get_shader_wrapper(self.camera.ctx).get_id() ) for group in self.render_groups: diff --git a/manimlib/shader_wrapper.py b/manimlib/shader_wrapper.py index 7c4d4bea..d811afcc 100644 --- a/manimlib/shader_wrapper.py +++ b/manimlib/shader_wrapper.py @@ -41,6 +41,7 @@ class ShaderWrapper(object): texture_paths: Optional[dict[str, str]] = None, # A dictionary mapping names to filepaths for textures. depth_test: bool = False, render_primitive: int = moderngl.TRIANGLE_STRIP, + code_replacements: dict[str, str] = dict(), ): self.ctx = ctx self.vert_data = vert_data @@ -48,13 +49,15 @@ class ShaderWrapper(object): self.shader_folder = shader_folder self.depth_test = depth_test self.render_primitive = render_primitive + self.texture_names_to_ids = dict() self.program_uniform_mirror: UniformDict = dict() self.bind_to_mobject_uniforms(mobject_uniforms or dict()) self.init_program_code() + for old, new in code_replacements.items(): + self.replace_code(old, new) self.init_program() - self.texture_names_to_ids = dict() if texture_paths is not None: self.init_textures(texture_paths) self.init_vertex_objects() @@ -98,19 +101,14 @@ class ShaderWrapper(object): def get_id(self) -> str: return self.id - def create_id(self) -> str: - # A unique id for a shader - program_id = hash("".join(map(str, self.program_code.values()))) - return "|".join(map(str, [ - program_id, + def refresh_id(self) -> None: + self.id = hash("".join(map(str, [ + "".join(map(str, self.program_code.values())), self.mobject_uniforms, self.depth_test, self.render_primitive, self.texture_names_to_ids, - ])) - - def refresh_id(self) -> None: - self.id = self.create_id() + ]))) def replace_code(self, old: str, new: str) -> None: code_map = self.program_code @@ -213,10 +211,12 @@ class VShaderWrapper(ShaderWrapper): mobject_uniforms: Optional[UniformDict] = None, # A dictionary mapping names of uniform variables texture_paths: Optional[dict[str, str]] = None, # A dictionary mapping names to filepaths for textures. depth_test: bool = False, - # render_primitive: int = moderngl.TRIANGLES, render_primitive: int = moderngl.TRIANGLE_STRIP, + code_replacements: dict[str, str] = dict(), stroke_behind: bool = False, ): + self.stroke_behind = stroke_behind + self.fill_canvas = get_fill_canvas(ctx) super().__init__( ctx=ctx, vert_data=vert_data, @@ -225,9 +225,8 @@ class VShaderWrapper(ShaderWrapper): texture_paths=texture_paths, depth_test=depth_test, render_primitive=render_primitive, + code_replacements=code_replacements, ) - self.stroke_behind = stroke_behind - self.fill_canvas = get_fill_canvas(self.ctx) def init_program_code(self) -> None: self.program_code = { @@ -312,6 +311,10 @@ class VShaderWrapper(ShaderWrapper): def set_backstroke(self, value: bool = True): self.stroke_behind = value + def refresh_id(self): + super().refresh_id() + self.id = hash(str(self.id) + str(self.stroke_behind)) + # TODO, motidify read in to handle triangulation case for non-winding fill? # Rendering