mirror of
https://github.com/3b1b/manim.git
synced 2025-11-14 13:17:44 +00:00
Speed improvements. Camera saves vbo for static mobjects, data in shader_data_info is saved and concatenated in raw form
This commit is contained in:
parent
72bfb0047e
commit
212cdbb4d2
5 changed files with 74 additions and 59 deletions
|
|
@ -168,6 +168,7 @@ class Camera(object):
|
|||
self.init_shaders()
|
||||
self.init_textures()
|
||||
self.init_light_source()
|
||||
self.static_mobjects_to_shader_info_list = {}
|
||||
|
||||
def init_frame(self):
|
||||
self.frame = CameraFrame(**self.frame_config)
|
||||
|
|
@ -316,34 +317,59 @@ class Camera(object):
|
|||
return fc + scale * np.array([(px - pw / 2), (py - ph / 2), 0])
|
||||
|
||||
# Rendering
|
||||
def set_mobjects_as_static(self, *mobjects):
|
||||
for mob in mobjects:
|
||||
info_list = mob.get_shader_info_list()
|
||||
for info in info_list:
|
||||
info["vbo"] = self.ctx.buffer(info["raw_data"])
|
||||
self.static_mobjects_to_shader_info_list[id(mob)] = info_list
|
||||
|
||||
def release_static_mobjects(self):
|
||||
for mob, info_list in self.static_mobjects_to_shader_info_list.items():
|
||||
for info in info_list:
|
||||
info["vbo"].release()
|
||||
self.static_mobjects_to_shader_info_list = {}
|
||||
|
||||
def capture(self, *mobjects, **kwargs):
|
||||
self.refresh_shader_uniforms()
|
||||
self.refresh_perspective_uniforms()
|
||||
|
||||
shader_infos = it.chain(*[mob.get_shader_info_list() for mob in mobjects])
|
||||
batches = batch_by_property(shader_infos, shader_info_to_id)
|
||||
# shader_infos = it.chain(*[mob.get_shader_info_list() for mob in mobjects])
|
||||
# batches = batch_by_property(shader_infos, shader_info_to_id)
|
||||
|
||||
for info_group, sid in batches:
|
||||
data = np.hstack([info["data"] for info in info_group])
|
||||
shader = self.get_shader(info_group[0])
|
||||
render_primative = int(info_group[0]["render_primative"])
|
||||
depth_test = info_group[0]["depth_test"]
|
||||
self.render(shader, data, render_primative, depth_test)
|
||||
# for shader_info_group, sid in batches:
|
||||
for mobject in mobjects:
|
||||
try:
|
||||
info_list = self.static_mobjects_to_shader_info_list[id(mobject)]
|
||||
except KeyError:
|
||||
info_list = mobject.get_shader_info_list()
|
||||
|
||||
def render(self, shader, data, render_primative, depth_test=False):
|
||||
if data is None or len(data) == 0:
|
||||
for shader_info in info_list:
|
||||
self.render(shader_info)
|
||||
|
||||
def render(self, shader_info):
|
||||
raw_data = shader_info["raw_data"]
|
||||
if not raw_data:
|
||||
return
|
||||
|
||||
shader = self.get_shader(shader_info)
|
||||
if shader is None:
|
||||
return
|
||||
if depth_test:
|
||||
self.set_perspective_uniforms(shader)
|
||||
|
||||
if shader_info["depth_test"]:
|
||||
self.ctx.enable(moderngl.DEPTH_TEST)
|
||||
else:
|
||||
self.ctx.disable(moderngl.DEPTH_TEST)
|
||||
|
||||
vbo = self.ctx.buffer(data.tobytes())
|
||||
vao = self.ctx.simple_vertex_array(shader, vbo, *data.dtype.names)
|
||||
vao.render(render_primative)
|
||||
vbo.release()
|
||||
if "vbo" in shader_info:
|
||||
vbo = shader_info["vbo"]
|
||||
else:
|
||||
vbo = self.ctx.buffer(raw_data)
|
||||
vao = self.ctx.simple_vertex_array(shader, vbo, *shader_info["attributes"])
|
||||
vao.render(int(shader_info["render_primative"]))
|
||||
vao.release()
|
||||
if "vbo" not in shader_info:
|
||||
vbo.release()
|
||||
|
||||
# Shaders
|
||||
def init_shaders(self):
|
||||
|
|
@ -356,7 +382,6 @@ class Camera(object):
|
|||
# Create shader program for the first time, then cache
|
||||
# in the id_to_shader dictionary
|
||||
shader = self.ctx.program(**shader_info_to_program_code(shader_info))
|
||||
self.set_shader_uniforms(shader)
|
||||
for name, path in shader_info["texture_paths"].items():
|
||||
tid = self.get_texture_id(path)
|
||||
shader[name].value = tid
|
||||
|
|
@ -365,10 +390,14 @@ class Camera(object):
|
|||
self.id_to_shader[sid] = shader
|
||||
return self.id_to_shader[sid]
|
||||
|
||||
def set_shader_uniforms(self, shader):
|
||||
if shader is None:
|
||||
return
|
||||
def set_perspective_uniforms(self, shader):
|
||||
for key, value in self.perspective_uniforms.items():
|
||||
try:
|
||||
shader[key].value = value
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
def refresh_perspective_uniforms(self):
|
||||
pw, ph = self.get_pixel_shape()
|
||||
fw, fh = self.frame.get_shape()
|
||||
# TODO, this should probably be a mobject uniform, with
|
||||
|
|
@ -377,7 +406,7 @@ class Camera(object):
|
|||
transform = self.frame.get_inverse_camera_position_matrix()
|
||||
light = self.light_source.get_location()
|
||||
transformed_light = np.dot(transform, [*light, 1])[:3]
|
||||
mapping = {
|
||||
self.perspective_uniforms = {
|
||||
'to_screen_space': tuple(transform.T.flatten()),
|
||||
'frame_shape': self.frame.get_shape(),
|
||||
'focal_distance': self.frame.get_focal_distance(),
|
||||
|
|
@ -385,16 +414,6 @@ class Camera(object):
|
|||
'light_source_position': tuple(transformed_light),
|
||||
}
|
||||
|
||||
for key, value in mapping.items():
|
||||
try:
|
||||
shader[key].value = value
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
def refresh_shader_uniforms(self):
|
||||
for sid, shader in self.id_to_shader.items():
|
||||
self.set_shader_uniforms(shader)
|
||||
|
||||
def init_textures(self):
|
||||
self.path_to_texture_id = {}
|
||||
|
||||
|
|
|
|||
|
|
@ -68,7 +68,6 @@ class Mobject(Container):
|
|||
self.time_based_updaters = []
|
||||
self.non_time_updaters = []
|
||||
self.updating_suspended = False
|
||||
self.shader_data_is_locked = False
|
||||
|
||||
self.reset_points()
|
||||
self.init_points()
|
||||
|
|
@ -1216,17 +1215,7 @@ class Mobject(Container):
|
|||
return new_arr
|
||||
return arr
|
||||
|
||||
def lock_shader_data(self):
|
||||
self.shader_data_is_locked = False
|
||||
self.saved_shader_info_list = self.get_shader_info_list()
|
||||
self.shader_data_is_locked = True
|
||||
|
||||
def unlock_shader_data(self):
|
||||
self.shader_data_is_locked = False
|
||||
|
||||
def get_shader_info_list(self):
|
||||
if self.shader_data_is_locked:
|
||||
return self.saved_shader_info_list
|
||||
shader_infos = it.chain(
|
||||
[self.get_shader_info()],
|
||||
*[sm.get_shader_info_list() for sm in self.submobjects]
|
||||
|
|
@ -1236,14 +1225,16 @@ class Mobject(Container):
|
|||
result = []
|
||||
for info_group, sid in batches:
|
||||
shader_info = info_group[0]
|
||||
shader_info["data"] = np.hstack([info["data"] for info in info_group])
|
||||
shader_info["raw_data"] = b''.join([info["raw_data"] for info in info_group])
|
||||
if is_valid_shader_info(shader_info):
|
||||
result.append(shader_info)
|
||||
return result
|
||||
|
||||
def get_shader_info(self):
|
||||
shader_info = dict(self.shader_info_template)
|
||||
shader_info["data"] = self.get_shader_data()
|
||||
data = self.get_shader_data()
|
||||
shader_info["raw_data"] = data.tobytes()
|
||||
shader_info["attributes"] = data.dtype.names
|
||||
shader_info["uniforms"] = self.get_shader_uniforms()
|
||||
return shader_info
|
||||
|
||||
|
|
|
|||
|
|
@ -847,6 +847,7 @@ class VMobject(Mobject):
|
|||
self.fill_data = np.zeros(len(self.points), dtype=self.fill_dtype)
|
||||
self.stroke_data = np.zeros(len(self.points), dtype=self.stroke_dtype)
|
||||
self.fill_shader_info_template = get_shader_info(
|
||||
attributes=self.fill_data.dtype.names,
|
||||
vert_file=self.fill_vert_shader_file,
|
||||
geom_file=self.fill_geom_shader_file,
|
||||
frag_file=self.fill_frag_shader_file,
|
||||
|
|
@ -854,6 +855,7 @@ class VMobject(Mobject):
|
|||
render_primative=self.render_primative,
|
||||
)
|
||||
self.stroke_shader_info_template = get_shader_info(
|
||||
attributes=self.stroke_data.dtype.names,
|
||||
vert_file=self.stroke_vert_shader_file,
|
||||
geom_file=self.stroke_geom_shader_file,
|
||||
frag_file=self.stroke_frag_shader_file,
|
||||
|
|
@ -881,7 +883,7 @@ class VMobject(Mobject):
|
|||
fill_opacity = submob.get_fill_opacity()
|
||||
|
||||
if fill_opacity > 0:
|
||||
fill_data.append(submob.get_fill_shader_data())
|
||||
fill_data.append(submob.get_fill_shader_data().tobytes())
|
||||
|
||||
if stroke_width > 0 and stroke_opacity > 0:
|
||||
if submob.draw_stroke_behind_fill:
|
||||
|
|
@ -889,18 +891,18 @@ class VMobject(Mobject):
|
|||
else:
|
||||
data = stroke_data
|
||||
new_data = submob.get_stroke_shader_data()
|
||||
data.append(new_data)
|
||||
data.append(new_data.tobytes())
|
||||
|
||||
result = []
|
||||
if back_stroke_data:
|
||||
back_stroke_info = dict(stroke_info) # Copy
|
||||
back_stroke_info["data"] = np.hstack(back_stroke_data)
|
||||
back_stroke_info["raw_data"] = b''.join(back_stroke_data)
|
||||
result.append(back_stroke_info)
|
||||
if fill_data:
|
||||
fill_info["data"] = np.hstack(fill_data)
|
||||
fill_info["raw_data"] = b''.join(fill_data)
|
||||
result.append(fill_info)
|
||||
if stroke_data:
|
||||
stroke_info["data"] = np.hstack(stroke_data)
|
||||
stroke_info["raw_data"] = b''.join(stroke_data)
|
||||
result.append(stroke_info)
|
||||
return result
|
||||
|
||||
|
|
|
|||
|
|
@ -375,11 +375,10 @@ class Scene(Container):
|
|||
continue
|
||||
if mobject.get_family_updaters():
|
||||
continue
|
||||
mobject.lock_shader_data()
|
||||
self.camera.set_mobjects_as_static(mobject)
|
||||
|
||||
def unlock_mobject_data(self):
|
||||
for mobject in self.mobjects:
|
||||
mobject.unlock_shader_data()
|
||||
self.camera.release_static_mobjects()
|
||||
|
||||
def begin_animations(self, animations):
|
||||
for animation in animations:
|
||||
|
|
|
|||
|
|
@ -15,7 +15,9 @@ from manimlib.constants import SHADER_DIR
|
|||
SHADER_INFO_KEYS = [
|
||||
# A structred array caring all of the points/color/lighting/etc. information
|
||||
# needed for the shader.
|
||||
"data",
|
||||
"raw_data",
|
||||
# List of variable names corresponding to inputs of vertex shader
|
||||
"attributes",
|
||||
# Filename of vetex shader
|
||||
"vert",
|
||||
# Filename of geometry shader, if there is one
|
||||
|
|
@ -33,11 +35,12 @@ SHADER_INFO_KEYS = [
|
|||
"render_primative",
|
||||
]
|
||||
|
||||
# Exclude data
|
||||
SHADER_KEYS_FOR_ID = SHADER_INFO_KEYS[1:]
|
||||
# Exclude raw_data
|
||||
SHADER_KEYS_FOR_ID = SHADER_INFO_KEYS[2:]
|
||||
|
||||
|
||||
def get_shader_info(data=None,
|
||||
def get_shader_info(raw_data=None,
|
||||
attributes=None,
|
||||
vert_file=None,
|
||||
geom_file=None,
|
||||
frag_file=None,
|
||||
|
|
@ -47,7 +50,8 @@ def get_shader_info(data=None,
|
|||
render_primative=moderngl.TRIANGLE_STRIP,
|
||||
):
|
||||
result = {
|
||||
"data": data,
|
||||
"raw_data": raw_data,
|
||||
"attributes": attributes,
|
||||
"vert": vert_file,
|
||||
"geom": geom_file,
|
||||
"frag": frag_file,
|
||||
|
|
@ -62,9 +66,9 @@ def get_shader_info(data=None,
|
|||
|
||||
|
||||
def is_valid_shader_info(shader_info):
|
||||
data = shader_info["data"]
|
||||
raw_data = shader_info["raw_data"]
|
||||
return all([
|
||||
data is not None and len(data) > 0,
|
||||
raw_data is not None and len(raw_data) > 0,
|
||||
shader_info["vert"],
|
||||
shader_info["frag"],
|
||||
])
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue