Rather than calling get_shader_info a bunch, remember a tempalte

This commit is contained in:
Grant Sanderson 2020-06-15 12:01:54 -07:00
parent 6a458547c3
commit c45fe52a70
2 changed files with 50 additions and 36 deletions

View file

@ -75,6 +75,11 @@ class Mobject(Container):
self.init_colors()
self.init_shader_data()
if self.is_fixed_in_frame:
self.fix_in_frame()
if self.depth_test:
self.apply_depth_test()
def __str__(self):
return str(self.name)
@ -1195,6 +1200,14 @@ class Mobject(Container):
# For shaders
def init_shader_data(self):
self.shader_data = np.zeros(len(self.points), dtype=self.shader_dtype)
self.shader_info_template = get_shader_info(
vert_file=self.vert_shader_file,
geom_file=self.geom_shader_file,
frag_file=self.frag_shader_file,
texture_paths=self.texture_paths,
depth_test=self.depth_test,
render_primative=self.render_primative,
)
def get_blank_shader_data_array(self, size, name="shader_data"):
# If possible, try to populate an existing array, rather
@ -1219,10 +1232,7 @@ class Mobject(Container):
return self.saved_shader_info_list
shader_infos = it.chain(
[self.get_shader_info()],
*[
submob.get_shader_info_list()
for submob in self.submobjects
]
*[sm.get_shader_info_list() for sm in self.submobjects]
)
batches = batch_by_property(shader_infos, shader_info_to_id)
@ -1235,16 +1245,10 @@ class Mobject(Container):
return result
def get_shader_info(self):
return get_shader_info(
data=self.get_shader_data(),
vert_file=self.vert_shader_file,
geom_file=self.geom_shader_file,
frag_file=self.frag_shader_file,
uniforms=self.get_shader_uniforms(),
texture_paths=self.texture_paths,
depth_test=self.depth_test,
render_primative=self.render_primative,
)
shader_info = dict(self.shader_info_template)
shader_info["data"] = self.get_shader_data()
shader_info["uniforms"] = self.get_shader_uniforms()
return shader_info
def get_shader_uniforms(self):
return {

View file

@ -9,7 +9,6 @@ from manimlib.constants import *
from manimlib.mobject.mobject import Mobject
from manimlib.mobject.mobject import Point
from manimlib.utils.bezier import bezier
from manimlib.utils.bezier import get_smooth_cubic_bezier_handle_points
from manimlib.utils.bezier import get_smooth_quadratic_bezier_handle_points
from manimlib.utils.bezier import get_quadratic_approximation_of_cubic
from manimlib.utils.bezier import interpolate
@ -84,8 +83,8 @@ class VMobject(Mobject):
self.unit_normal_locked = False
self.triangulation_locked = False
super().__init__(**kwargs)
self.lock_unit_normal()
self.lock_triangulation()
self.lock_unit_normal(family=False)
self.lock_triangulation(family=False)
def get_group_class(self):
return VGroup
@ -676,7 +675,11 @@ class VMobject(Mobject):
return self
def refresh_unit_normal(self):
self.lock_unit_normal()
for mob in self.get_family():
mob.unit_normal_locked = False
mob.saved_unit_normal = mob.get_unit_normal()
mob.unit_normal_locked = True
return self
# Alignment
def align_points(self, vmobject):
@ -842,27 +845,29 @@ class VMobject(Mobject):
def init_shader_data(self):
self.fill_data = np.zeros(len(self.points), dtype=self.fill_dtype)
self.stroke_data = np.zeros(len(self.points), dtype=self.stroke_dtype)
self.fill_shader_info_template = get_shader_info(
vert_file=self.fill_vert_shader_file,
geom_file=self.fill_geom_shader_file,
frag_file=self.fill_frag_shader_file,
depth_test=self.depth_test,
render_primative=self.render_primative,
)
self.stroke_shader_info_template = get_shader_info(
vert_file=self.stroke_vert_shader_file,
geom_file=self.stroke_geom_shader_file,
frag_file=self.stroke_frag_shader_file,
depth_test=self.depth_test,
render_primative=self.render_primative,
)
def get_shader_info_list(self):
if self.shader_data_is_locked:
return self.saved_shader_info_list
stroke_info = get_shader_info(
vert_file=self.stroke_vert_shader_file,
geom_file=self.stroke_geom_shader_file,
frag_file=self.stroke_frag_shader_file,
uniforms=self.get_stroke_uniforms(),
depth_test=self.depth_test,
render_primative=self.render_primative,
)
fill_info = get_shader_info(
vert_file=self.fill_vert_shader_file,
geom_file=self.fill_geom_shader_file,
frag_file=self.fill_frag_shader_file,
uniforms=self.get_shader_uniforms(),
depth_test=self.depth_test,
render_primative=self.render_primative,
)
fill_info = dict(self.fill_shader_info_template)
stroke_info = dict(self.stroke_shader_info_template)
fill_info["uniforms"] = self.get_shader_uniforms()
stroke_info["uniforms"] = self.get_stroke_uniforms()
back_stroke_data = []
stroke_data = []
@ -941,11 +946,16 @@ class VMobject(Mobject):
return self
def unlock_triangulation(self):
for sm in self.family_members_with_points():
for sm in self.get_family():
sm.triangulation_locked = False
def refresh_triangulation(self):
self.lock_triangulation()
for mob in self.get_family():
if mob.triangulation_locked:
mob.triangulation_locked = False
mob.saved_triangulation = mob.get_triangulation()
mob.triangulation_locked = True
return self
def get_triangulation(self, normal_vector=None):
# Figure out how to triangulate the interior to know