Speed improvements. Camera saves vbo for static mobjects, data in shader_data_info is saved and concatenated in raw form

This commit is contained in:
Grant Sanderson 2020-06-26 19:29:34 -07:00
parent 72bfb0047e
commit 212cdbb4d2
5 changed files with 74 additions and 59 deletions

View file

@ -168,6 +168,7 @@ class Camera(object):
self.init_shaders()
self.init_textures()
self.init_light_source()
self.static_mobjects_to_shader_info_list = {}
def init_frame(self):
self.frame = CameraFrame(**self.frame_config)
@ -316,34 +317,59 @@ class Camera(object):
return fc + scale * np.array([(px - pw / 2), (py - ph / 2), 0])
# Rendering
def set_mobjects_as_static(self, *mobjects):
for mob in mobjects:
info_list = mob.get_shader_info_list()
for info in info_list:
info["vbo"] = self.ctx.buffer(info["raw_data"])
self.static_mobjects_to_shader_info_list[id(mob)] = info_list
def release_static_mobjects(self):
for mob, info_list in self.static_mobjects_to_shader_info_list.items():
for info in info_list:
info["vbo"].release()
self.static_mobjects_to_shader_info_list = {}
def capture(self, *mobjects, **kwargs):
self.refresh_shader_uniforms()
self.refresh_perspective_uniforms()
shader_infos = it.chain(*[mob.get_shader_info_list() for mob in mobjects])
batches = batch_by_property(shader_infos, shader_info_to_id)
# shader_infos = it.chain(*[mob.get_shader_info_list() for mob in mobjects])
# batches = batch_by_property(shader_infos, shader_info_to_id)
for info_group, sid in batches:
data = np.hstack([info["data"] for info in info_group])
shader = self.get_shader(info_group[0])
render_primative = int(info_group[0]["render_primative"])
depth_test = info_group[0]["depth_test"]
self.render(shader, data, render_primative, depth_test)
# for shader_info_group, sid in batches:
for mobject in mobjects:
try:
info_list = self.static_mobjects_to_shader_info_list[id(mobject)]
except KeyError:
info_list = mobject.get_shader_info_list()
def render(self, shader, data, render_primative, depth_test=False):
if data is None or len(data) == 0:
for shader_info in info_list:
self.render(shader_info)
def render(self, shader_info):
raw_data = shader_info["raw_data"]
if not raw_data:
return
shader = self.get_shader(shader_info)
if shader is None:
return
if depth_test:
self.set_perspective_uniforms(shader)
if shader_info["depth_test"]:
self.ctx.enable(moderngl.DEPTH_TEST)
else:
self.ctx.disable(moderngl.DEPTH_TEST)
vbo = self.ctx.buffer(data.tobytes())
vao = self.ctx.simple_vertex_array(shader, vbo, *data.dtype.names)
vao.render(render_primative)
vbo.release()
if "vbo" in shader_info:
vbo = shader_info["vbo"]
else:
vbo = self.ctx.buffer(raw_data)
vao = self.ctx.simple_vertex_array(shader, vbo, *shader_info["attributes"])
vao.render(int(shader_info["render_primative"]))
vao.release()
if "vbo" not in shader_info:
vbo.release()
# Shaders
def init_shaders(self):
@ -356,7 +382,6 @@ class Camera(object):
# Create shader program for the first time, then cache
# in the id_to_shader dictionary
shader = self.ctx.program(**shader_info_to_program_code(shader_info))
self.set_shader_uniforms(shader)
for name, path in shader_info["texture_paths"].items():
tid = self.get_texture_id(path)
shader[name].value = tid
@ -365,10 +390,14 @@ class Camera(object):
self.id_to_shader[sid] = shader
return self.id_to_shader[sid]
def set_shader_uniforms(self, shader):
if shader is None:
return
def set_perspective_uniforms(self, shader):
for key, value in self.perspective_uniforms.items():
try:
shader[key].value = value
except KeyError:
pass
def refresh_perspective_uniforms(self):
pw, ph = self.get_pixel_shape()
fw, fh = self.frame.get_shape()
# TODO, this should probably be a mobject uniform, with
@ -377,7 +406,7 @@ class Camera(object):
transform = self.frame.get_inverse_camera_position_matrix()
light = self.light_source.get_location()
transformed_light = np.dot(transform, [*light, 1])[:3]
mapping = {
self.perspective_uniforms = {
'to_screen_space': tuple(transform.T.flatten()),
'frame_shape': self.frame.get_shape(),
'focal_distance': self.frame.get_focal_distance(),
@ -385,16 +414,6 @@ class Camera(object):
'light_source_position': tuple(transformed_light),
}
for key, value in mapping.items():
try:
shader[key].value = value
except KeyError:
pass
def refresh_shader_uniforms(self):
for sid, shader in self.id_to_shader.items():
self.set_shader_uniforms(shader)
def init_textures(self):
self.path_to_texture_id = {}

View file

@ -68,7 +68,6 @@ class Mobject(Container):
self.time_based_updaters = []
self.non_time_updaters = []
self.updating_suspended = False
self.shader_data_is_locked = False
self.reset_points()
self.init_points()
@ -1216,17 +1215,7 @@ class Mobject(Container):
return new_arr
return arr
def lock_shader_data(self):
self.shader_data_is_locked = False
self.saved_shader_info_list = self.get_shader_info_list()
self.shader_data_is_locked = True
def unlock_shader_data(self):
self.shader_data_is_locked = False
def get_shader_info_list(self):
if self.shader_data_is_locked:
return self.saved_shader_info_list
shader_infos = it.chain(
[self.get_shader_info()],
*[sm.get_shader_info_list() for sm in self.submobjects]
@ -1236,14 +1225,16 @@ class Mobject(Container):
result = []
for info_group, sid in batches:
shader_info = info_group[0]
shader_info["data"] = np.hstack([info["data"] for info in info_group])
shader_info["raw_data"] = b''.join([info["raw_data"] for info in info_group])
if is_valid_shader_info(shader_info):
result.append(shader_info)
return result
def get_shader_info(self):
shader_info = dict(self.shader_info_template)
shader_info["data"] = self.get_shader_data()
data = self.get_shader_data()
shader_info["raw_data"] = data.tobytes()
shader_info["attributes"] = data.dtype.names
shader_info["uniforms"] = self.get_shader_uniforms()
return shader_info

View file

@ -847,6 +847,7 @@ class VMobject(Mobject):
self.fill_data = np.zeros(len(self.points), dtype=self.fill_dtype)
self.stroke_data = np.zeros(len(self.points), dtype=self.stroke_dtype)
self.fill_shader_info_template = get_shader_info(
attributes=self.fill_data.dtype.names,
vert_file=self.fill_vert_shader_file,
geom_file=self.fill_geom_shader_file,
frag_file=self.fill_frag_shader_file,
@ -854,6 +855,7 @@ class VMobject(Mobject):
render_primative=self.render_primative,
)
self.stroke_shader_info_template = get_shader_info(
attributes=self.stroke_data.dtype.names,
vert_file=self.stroke_vert_shader_file,
geom_file=self.stroke_geom_shader_file,
frag_file=self.stroke_frag_shader_file,
@ -881,7 +883,7 @@ class VMobject(Mobject):
fill_opacity = submob.get_fill_opacity()
if fill_opacity > 0:
fill_data.append(submob.get_fill_shader_data())
fill_data.append(submob.get_fill_shader_data().tobytes())
if stroke_width > 0 and stroke_opacity > 0:
if submob.draw_stroke_behind_fill:
@ -889,18 +891,18 @@ class VMobject(Mobject):
else:
data = stroke_data
new_data = submob.get_stroke_shader_data()
data.append(new_data)
data.append(new_data.tobytes())
result = []
if back_stroke_data:
back_stroke_info = dict(stroke_info) # Copy
back_stroke_info["data"] = np.hstack(back_stroke_data)
back_stroke_info["raw_data"] = b''.join(back_stroke_data)
result.append(back_stroke_info)
if fill_data:
fill_info["data"] = np.hstack(fill_data)
fill_info["raw_data"] = b''.join(fill_data)
result.append(fill_info)
if stroke_data:
stroke_info["data"] = np.hstack(stroke_data)
stroke_info["raw_data"] = b''.join(stroke_data)
result.append(stroke_info)
return result

View file

@ -375,11 +375,10 @@ class Scene(Container):
continue
if mobject.get_family_updaters():
continue
mobject.lock_shader_data()
self.camera.set_mobjects_as_static(mobject)
def unlock_mobject_data(self):
for mobject in self.mobjects:
mobject.unlock_shader_data()
self.camera.release_static_mobjects()
def begin_animations(self, animations):
for animation in animations:

View file

@ -15,7 +15,9 @@ from manimlib.constants import SHADER_DIR
SHADER_INFO_KEYS = [
# A structred array caring all of the points/color/lighting/etc. information
# needed for the shader.
"data",
"raw_data",
# List of variable names corresponding to inputs of vertex shader
"attributes",
# Filename of vetex shader
"vert",
# Filename of geometry shader, if there is one
@ -33,11 +35,12 @@ SHADER_INFO_KEYS = [
"render_primative",
]
# Exclude data
SHADER_KEYS_FOR_ID = SHADER_INFO_KEYS[1:]
# Exclude raw_data
SHADER_KEYS_FOR_ID = SHADER_INFO_KEYS[2:]
def get_shader_info(data=None,
def get_shader_info(raw_data=None,
attributes=None,
vert_file=None,
geom_file=None,
frag_file=None,
@ -47,7 +50,8 @@ def get_shader_info(data=None,
render_primative=moderngl.TRIANGLE_STRIP,
):
result = {
"data": data,
"raw_data": raw_data,
"attributes": attributes,
"vert": vert_file,
"geom": geom_file,
"frag": frag_file,
@ -62,9 +66,9 @@ def get_shader_info(data=None,
def is_valid_shader_info(shader_info):
data = shader_info["data"]
raw_data = shader_info["raw_data"]
return all([
data is not None and len(data) > 0,
raw_data is not None and len(raw_data) > 0,
shader_info["vert"],
shader_info["frag"],
])