diff --git a/manimlib/mobject/mobject.py b/manimlib/mobject/mobject.py index 75a4d010..1202b996 100644 --- a/manimlib/mobject/mobject.py +++ b/manimlib/mobject/mobject.py @@ -15,6 +15,7 @@ from manimlib.constants import * from manimlib.container.container import Container from manimlib.utils.color import color_gradient from manimlib.utils.color import interpolate_color +from manimlib.utils.iterables import batch_by_property from manimlib.utils.iterables import list_update from manimlib.utils.iterables import remove_list_redundancies from manimlib.utils.paths import straight_path @@ -22,6 +23,10 @@ from manimlib.utils.simple_functions import get_parameters from manimlib.utils.space_ops import angle_of_vector from manimlib.utils.space_ops import get_norm from manimlib.utils.space_ops import rotation_matrix +from manimlib.utils.shaders import get_shader_info +from manimlib.utils.shaders import shader_info_to_id +from manimlib.utils.shaders import shader_id_to_info +from manimlib.utils.shaders import is_valid_shader_info # TODO: Explain array_attrs @@ -1143,22 +1148,37 @@ class Mobject(Container): return arr def get_shader_info_list(self): - return [self.get_shader_info()] + shader_infos = it.chain( + [self.get_shader_info()], + *[ + submob.get_shader_info_list() + for submob in self.submobjects + ] + ) + batches = batch_by_property(shader_infos, shader_info_to_id) + + result = [] + for info_group, sid in batches: + shader_info = shader_id_to_info(sid) + shader_info["data"] = np.hstack([info["data"] for info in info_group]) + if is_valid_shader_info(shader_info): + result.append(shader_info) + return result def get_shader_info(self): - return { - "data": self.get_shader_data(), - "vert": self.vert_shader_file, - "geom": self.geom_shader_file, - "frag": self.frag_shader_file, - "render_primative": self.render_primative, - "texture_path": self.texture_path, - } + return get_shader_info( + data=self.get_shader_data(), + vert_file=self.vert_shader_file, + geom_file=self.geom_shader_file, + frag_file=self.frag_shader_file, + texture_path=self.texture_path, + render_primative=self.render_primative, + ) def get_shader_data(self): - # To be implemented by subclasses + # Typically to be implemented by subclasses # Must return a structured numpy array - pass + return self.shader_data # Errors def throw_error_if_no_points(self):