Rather than calling get_shader_info a bunch, remember a tempalte

This commit is contained in:
Grant Sanderson 2020-06-15 12:01:54 -07:00
parent 6a458547c3
commit c45fe52a70
2 changed files with 50 additions and 36 deletions

View file

@ -75,6 +75,11 @@ class Mobject(Container):
self.init_colors() self.init_colors()
self.init_shader_data() self.init_shader_data()
if self.is_fixed_in_frame:
self.fix_in_frame()
if self.depth_test:
self.apply_depth_test()
def __str__(self): def __str__(self):
return str(self.name) return str(self.name)
@ -1195,6 +1200,14 @@ class Mobject(Container):
# For shaders # For shaders
def init_shader_data(self): def init_shader_data(self):
self.shader_data = np.zeros(len(self.points), dtype=self.shader_dtype) self.shader_data = np.zeros(len(self.points), dtype=self.shader_dtype)
self.shader_info_template = get_shader_info(
vert_file=self.vert_shader_file,
geom_file=self.geom_shader_file,
frag_file=self.frag_shader_file,
texture_paths=self.texture_paths,
depth_test=self.depth_test,
render_primative=self.render_primative,
)
def get_blank_shader_data_array(self, size, name="shader_data"): def get_blank_shader_data_array(self, size, name="shader_data"):
# If possible, try to populate an existing array, rather # If possible, try to populate an existing array, rather
@ -1219,10 +1232,7 @@ class Mobject(Container):
return self.saved_shader_info_list return self.saved_shader_info_list
shader_infos = it.chain( shader_infos = it.chain(
[self.get_shader_info()], [self.get_shader_info()],
*[ *[sm.get_shader_info_list() for sm in self.submobjects]
submob.get_shader_info_list()
for submob in self.submobjects
]
) )
batches = batch_by_property(shader_infos, shader_info_to_id) batches = batch_by_property(shader_infos, shader_info_to_id)
@ -1235,16 +1245,10 @@ class Mobject(Container):
return result return result
def get_shader_info(self): def get_shader_info(self):
return get_shader_info( shader_info = dict(self.shader_info_template)
data=self.get_shader_data(), shader_info["data"] = self.get_shader_data()
vert_file=self.vert_shader_file, shader_info["uniforms"] = self.get_shader_uniforms()
geom_file=self.geom_shader_file, return shader_info
frag_file=self.frag_shader_file,
uniforms=self.get_shader_uniforms(),
texture_paths=self.texture_paths,
depth_test=self.depth_test,
render_primative=self.render_primative,
)
def get_shader_uniforms(self): def get_shader_uniforms(self):
return { return {

View file

@ -9,7 +9,6 @@ from manimlib.constants import *
from manimlib.mobject.mobject import Mobject from manimlib.mobject.mobject import Mobject
from manimlib.mobject.mobject import Point from manimlib.mobject.mobject import Point
from manimlib.utils.bezier import bezier from manimlib.utils.bezier import bezier
from manimlib.utils.bezier import get_smooth_cubic_bezier_handle_points
from manimlib.utils.bezier import get_smooth_quadratic_bezier_handle_points from manimlib.utils.bezier import get_smooth_quadratic_bezier_handle_points
from manimlib.utils.bezier import get_quadratic_approximation_of_cubic from manimlib.utils.bezier import get_quadratic_approximation_of_cubic
from manimlib.utils.bezier import interpolate from manimlib.utils.bezier import interpolate
@ -84,8 +83,8 @@ class VMobject(Mobject):
self.unit_normal_locked = False self.unit_normal_locked = False
self.triangulation_locked = False self.triangulation_locked = False
super().__init__(**kwargs) super().__init__(**kwargs)
self.lock_unit_normal() self.lock_unit_normal(family=False)
self.lock_triangulation() self.lock_triangulation(family=False)
def get_group_class(self): def get_group_class(self):
return VGroup return VGroup
@ -676,7 +675,11 @@ class VMobject(Mobject):
return self return self
def refresh_unit_normal(self): def refresh_unit_normal(self):
self.lock_unit_normal() for mob in self.get_family():
mob.unit_normal_locked = False
mob.saved_unit_normal = mob.get_unit_normal()
mob.unit_normal_locked = True
return self
# Alignment # Alignment
def align_points(self, vmobject): def align_points(self, vmobject):
@ -842,27 +845,29 @@ class VMobject(Mobject):
def init_shader_data(self): def init_shader_data(self):
self.fill_data = np.zeros(len(self.points), dtype=self.fill_dtype) self.fill_data = np.zeros(len(self.points), dtype=self.fill_dtype)
self.stroke_data = np.zeros(len(self.points), dtype=self.stroke_dtype) self.stroke_data = np.zeros(len(self.points), dtype=self.stroke_dtype)
self.fill_shader_info_template = get_shader_info(
vert_file=self.fill_vert_shader_file,
geom_file=self.fill_geom_shader_file,
frag_file=self.fill_frag_shader_file,
depth_test=self.depth_test,
render_primative=self.render_primative,
)
self.stroke_shader_info_template = get_shader_info(
vert_file=self.stroke_vert_shader_file,
geom_file=self.stroke_geom_shader_file,
frag_file=self.stroke_frag_shader_file,
depth_test=self.depth_test,
render_primative=self.render_primative,
)
def get_shader_info_list(self): def get_shader_info_list(self):
if self.shader_data_is_locked: if self.shader_data_is_locked:
return self.saved_shader_info_list return self.saved_shader_info_list
stroke_info = get_shader_info( fill_info = dict(self.fill_shader_info_template)
vert_file=self.stroke_vert_shader_file, stroke_info = dict(self.stroke_shader_info_template)
geom_file=self.stroke_geom_shader_file, fill_info["uniforms"] = self.get_shader_uniforms()
frag_file=self.stroke_frag_shader_file, stroke_info["uniforms"] = self.get_stroke_uniforms()
uniforms=self.get_stroke_uniforms(),
depth_test=self.depth_test,
render_primative=self.render_primative,
)
fill_info = get_shader_info(
vert_file=self.fill_vert_shader_file,
geom_file=self.fill_geom_shader_file,
frag_file=self.fill_frag_shader_file,
uniforms=self.get_shader_uniforms(),
depth_test=self.depth_test,
render_primative=self.render_primative,
)
back_stroke_data = [] back_stroke_data = []
stroke_data = [] stroke_data = []
@ -941,11 +946,16 @@ class VMobject(Mobject):
return self return self
def unlock_triangulation(self): def unlock_triangulation(self):
for sm in self.family_members_with_points(): for sm in self.get_family():
sm.triangulation_locked = False sm.triangulation_locked = False
def refresh_triangulation(self): def refresh_triangulation(self):
self.lock_triangulation() for mob in self.get_family():
if mob.triangulation_locked:
mob.triangulation_locked = False
mob.saved_triangulation = mob.get_triangulation()
mob.triangulation_locked = True
return self
def get_triangulation(self, normal_vector=None): def get_triangulation(self, normal_vector=None):
# Figure out how to triangulate the interior to know # Figure out how to triangulate the interior to know