Have mobject.get_shader_info_list handle all of its submobjects

This commit is contained in:
Grant Sanderson 2020-02-17 12:15:53 -08:00
parent 9d4b16d03f
commit 47daf8e7f7

View file

@ -15,6 +15,7 @@ from manimlib.constants import *
from manimlib.container.container import Container from manimlib.container.container import Container
from manimlib.utils.color import color_gradient from manimlib.utils.color import color_gradient
from manimlib.utils.color import interpolate_color from manimlib.utils.color import interpolate_color
from manimlib.utils.iterables import batch_by_property
from manimlib.utils.iterables import list_update from manimlib.utils.iterables import list_update
from manimlib.utils.iterables import remove_list_redundancies from manimlib.utils.iterables import remove_list_redundancies
from manimlib.utils.paths import straight_path from manimlib.utils.paths import straight_path
@ -22,6 +23,10 @@ from manimlib.utils.simple_functions import get_parameters
from manimlib.utils.space_ops import angle_of_vector from manimlib.utils.space_ops import angle_of_vector
from manimlib.utils.space_ops import get_norm from manimlib.utils.space_ops import get_norm
from manimlib.utils.space_ops import rotation_matrix from manimlib.utils.space_ops import rotation_matrix
from manimlib.utils.shaders import get_shader_info
from manimlib.utils.shaders import shader_info_to_id
from manimlib.utils.shaders import shader_id_to_info
from manimlib.utils.shaders import is_valid_shader_info
# TODO: Explain array_attrs # TODO: Explain array_attrs
@ -1143,22 +1148,37 @@ class Mobject(Container):
return arr return arr
def get_shader_info_list(self): def get_shader_info_list(self):
return [self.get_shader_info()] shader_infos = it.chain(
[self.get_shader_info()],
*[
submob.get_shader_info_list()
for submob in self.submobjects
]
)
batches = batch_by_property(shader_infos, shader_info_to_id)
result = []
for info_group, sid in batches:
shader_info = shader_id_to_info(sid)
shader_info["data"] = np.hstack([info["data"] for info in info_group])
if is_valid_shader_info(shader_info):
result.append(shader_info)
return result
def get_shader_info(self): def get_shader_info(self):
return { return get_shader_info(
"data": self.get_shader_data(), data=self.get_shader_data(),
"vert": self.vert_shader_file, vert_file=self.vert_shader_file,
"geom": self.geom_shader_file, geom_file=self.geom_shader_file,
"frag": self.frag_shader_file, frag_file=self.frag_shader_file,
"render_primative": self.render_primative, texture_path=self.texture_path,
"texture_path": self.texture_path, render_primative=self.render_primative,
} )
def get_shader_data(self): def get_shader_data(self):
# To be implemented by subclasses # Typically to be implemented by subclasses
# Must return a structured numpy array # Must return a structured numpy array
pass return self.shader_data
# Errors # Errors
def throw_error_if_no_points(self): def throw_error_if_no_points(self):