Use an index buffer for shaders to save memory

This commit is contained in:
Grant Sanderson 2020-06-29 11:05:09 -07:00
parent 978137b143
commit 2b3bd2bfce
5 changed files with 96 additions and 57 deletions

View file

@ -331,9 +331,9 @@ class Camera(object):
def render(self, shader_info):
cached_buffers = "render_group" in shader_info
if cached_buffers:
vbo, vao, shader = shader_info["render_group"]
vbo, ibo, vao, shader = shader_info["render_group"]
else:
vbo, vao, shader = self.get_render_group(shader_info)
vbo, ibo, vao, shader = self.get_render_group(shader_info)
self.set_shader_uniforms(shader, shader_info)
@ -345,17 +345,25 @@ class Camera(object):
vao.render(int(shader_info["render_primative"]))
if not cached_buffers:
vbo.release()
vao.release()
self.release_gl_objects(vbo, ibo, vao)
def get_render_group(self, shader_info):
shader, vert_format = self.get_shader(shader_info)
vbo = self.ctx.buffer(shader_info["raw_data"])
# vbo = self.ctx.buffer(shader_info["vert_data"].tobytes())
vbo = self.ctx.buffer(shader_info["vert_data"])
vert_indices = shader_info["vert_indices"]
if vert_indices is None:
ibo = None
else:
ibo = self.ctx.buffer(vert_indices.astype('i4').tobytes())
vao = self.ctx.vertex_array(
program=shader,
content=[(vbo, vert_format, *shader_info["attributes"])]
content=[(vbo, vert_format, *shader_info["attributes"])],
index_buffer=ibo,
)
return (vbo, vao, shader)
return (vbo, ibo, vao, shader)
def set_mobjects_as_static(self, *mobjects):
# Create buffer and array objects holding each mobjects shader data
@ -368,11 +376,14 @@ class Camera(object):
def release_static_mobjects(self):
for mob, info_list in self.static_mobjects_to_shader_info_list.items():
for info in info_list:
vbo, vao, shader = info["render_group"]
vbo.release()
vao.release()
self.release_gl_objects(*info["render_group"][:3])
self.static_mobjects_to_shader_info_list = {}
def release_gl_objects(self, *objs):
for obj in objs:
if obj:
obj.release()
# Shaders
def init_shaders(self):
# Initialize with the null id going to None
@ -387,7 +398,7 @@ class Camera(object):
vert_format = moderngl.detect_format(program, shader_info["attributes"])
self.id_to_shader[sid] = (program, vert_format)
program, vert_format = self.id_to_shader[sid]
self.set_shader_uniforms(program, shader_info)
# self.set_shader_uniforms(program, shader_info)
return program, vert_format
def set_shader_uniforms(self, shader, shader_info):

View file

@ -1218,7 +1218,9 @@ class Mobject(Container):
# For shaders
def init_shader_data(self):
self.shader_data = np.zeros(len(self.points), dtype=self.shader_dtype)
self.shader_indices = None
self.shader_info_template = get_shader_info(
attributes=self.shader_data.dtype.names,
vert_file=self.vert_shader_file,
geom_file=self.geom_shader_file,
frag_file=self.frag_shader_file,
@ -1250,17 +1252,29 @@ class Mobject(Container):
result = []
for info_group, sid in batches:
shader_info = info_group[0]
shader_info["raw_data"] = b''.join([info["raw_data"] for info in info_group])
if is_valid_shader_info(shader_info):
result.append(shader_info)
combined_info = info_group[0]
if not is_valid_shader_info(combined_info):
continue
data_list = []
indices_list = []
num_verts = 0
for info in info_group:
data_list.append(info["vert_data"])
if info["vert_indices"] is not None:
indices_list.append(info["vert_indices"] + num_verts)
num_verts += len(info["vert_data"])
# Combine lists
combined_info["vert_data"] = np.hstack(data_list)
if combined_info["vert_indices"] is not None:
combined_info["vert_indices"] = np.hstack(indices_list)
if len(combined_info["vert_indices"]) > 0:
result.append(combined_info)
return result
def get_shader_info(self):
shader_info = dict(self.shader_info_template)
data = self.get_shader_data()
shader_info["raw_data"] = data.tobytes()
shader_info["attributes"] = data.dtype.names
shader_info["vert_data"] = self.get_shader_data()
shader_info["vert_indices"] = self.get_shader_vert_indices()
shader_info["uniforms"] = self.get_shader_uniforms()
return shader_info
@ -1276,6 +1290,9 @@ class Mobject(Container):
# Must return a structured numpy array
return self.shader_data
def get_shader_vert_indices(self):
return self.shader_indices
# Errors
def throw_error_if_no_points(self):
if self.has_no_points():

View file

@ -70,8 +70,9 @@ class ParametricSurface(Mobject):
# the resolution of the surface, make sure
# this is called.
nu, nv = self.resolution
if nu == 0 and nv == 0:
return np.zeros(0, dtype=int)
if nu == 0 or nv == 0:
self.triangle_indices = np.zeros(0, dtype=int)
return
index_grid = np.arange(nu * nv).reshape((nu, nv))
indices = np.zeros(6 * (nu - 1) * (nv - 1), dtype=int)
indices[0::6] = index_grid[:-1, :-1].flatten() # Top left
@ -124,13 +125,10 @@ class ParametricSurface(Mobject):
def get_shader_data(self):
s_points, du_points, dv_points = self.get_surface_points_and_nudged_points()
tri_indices = self.get_triangle_indices()
data = self.get_blank_shader_data_array(len(tri_indices))
if len(tri_indices) == 0:
return data
data["point"] = s_points[tri_indices]
data["du_point"] = du_points[tri_indices]
data["dv_point"] = dv_points[tri_indices]
data = self.get_blank_shader_data_array(len(s_points))
data["point"] = s_points
data["du_point"] = du_points
data["dv_point"] = dv_points
self.fill_in_shader_color_info(data)
return data
@ -138,6 +136,9 @@ class ParametricSurface(Mobject):
data["color"] = self.rgbas
return data
def get_shader_vert_indices(self):
return self.get_triangle_indices()
class SGroup(ParametricSurface):
CONFIG = {
@ -145,15 +146,11 @@ class SGroup(ParametricSurface):
}
def __init__(self, *parametric_surfaces, **kwargs):
# TODO, separate out the surface type...again
super().__init__(uv_func=None, **kwargs)
self.add(*parametric_surfaces)
def init_points(self):
pass
def get_triangle_indices(self):
return np.zeros(0)
self.points = np.zeros((0, 3))
class TexturedSurface(ParametricSurface):

View file

@ -912,31 +912,39 @@ class VMobject(Mobject):
for info in fill_info, stroke_info:
info["depth_test"] = self.depth_test
# Build up data lists
back_stroke_data = []
stroke_data = []
fill_data = []
fill_vert_indices = []
num_fill_verts = 0 # Number of fill verts
for submob in self.family_members_with_points():
if submob.has_fill():
data = submob.get_fill_shader_data().tobytes()
fill_data.append(data)
if submob.has_stroke():
if submob.draw_stroke_behind_fill:
data = back_stroke_data
else:
data = stroke_data
new_data = submob.get_stroke_shader_data()
data.append(new_data.tobytes())
data = submob.get_fill_shader_data()
indices = submob.get_fill_shader_vert_indices() + num_fill_verts
num_fill_verts += len(data)
fill_data.append(data)
fill_vert_indices.append(indices)
if submob.has_stroke():
data = submob.get_stroke_shader_data()
if submob.draw_stroke_behind_fill:
back_stroke_data.append(data)
else:
stroke_data.append(data)
# Combine data lists
result = []
if back_stroke_data:
back_stroke_info = dict(stroke_info) # Copy
back_stroke_info["raw_data"] = b''.join(back_stroke_data)
back_stroke_info["vert_data"] = np.hstack(back_stroke_data)
result.append(back_stroke_info)
if fill_data:
fill_info["raw_data"] = b''.join(fill_data)
fill_info["vert_data"] = np.hstack(fill_data)
fill_info["vert_indices"] = np.hstack(fill_vert_indices)
result.append(fill_info)
if stroke_data:
stroke_info["raw_data"] = b''.join(stroke_data)
stroke_info["vert_data"] = np.hstack(stroke_data)
result.append(stroke_info)
return result
@ -1046,19 +1054,22 @@ class VMobject(Mobject):
def get_fill_shader_data(self):
points = self.points
n_points = len(points)
unit_normal = self.get_unit_normal()
tri_indices = self.get_triangulation(unit_normal)
# TODO, best way to enable multiple colors?
rgbas = self.get_fill_rgbas()[:1]
data = self.get_blank_shader_data_array(len(tri_indices), "fill_data")
data["point"] = points[tri_indices]
data = self.get_blank_shader_data_array(n_points, "fill_data")
data["point"] = points
data["unit_normal"] = unit_normal
data["color"] = rgbas
data["vert_index"][:, 0] = tri_indices
data["vert_index"][:, 0] = self.saved_triangulation[:n_points]
return data
def get_fill_shader_vert_indices(self):
return self.get_triangulation()
class VGroup(VMobject):
def __init__(self, *vmobjects, **kwargs):

View file

@ -22,9 +22,10 @@ from manimlib.constants import SHADER_DIR
SHADER_INFO_KEYS = [
# A structred array caring all of the points/color/lighting/etc. information
# needed for the shader.
"raw_data",
# Vertex data for the shader (as structured array)
"vert_data",
# Index data (if applicable) for the shader
"index_data",
# List of variable names corresponding to inputs of vertex shader
"attributes",
# Filename of vetex shader
@ -44,11 +45,12 @@ SHADER_INFO_KEYS = [
"render_primative",
]
# Exclude raw_data
SHADER_KEYS_FOR_ID = SHADER_INFO_KEYS[2:]
# Exclude data
SHADER_KEYS_FOR_ID = SHADER_INFO_KEYS[3:]
def get_shader_info(raw_data=None,
def get_shader_info(vert_data=None,
vert_indices=None,
attributes=None,
vert_file=None,
geom_file=None,
@ -59,7 +61,8 @@ def get_shader_info(raw_data=None,
render_primative=moderngl.TRIANGLE_STRIP,
):
result = {
"raw_data": raw_data,
"vert_data": vert_data,
"vert_indices": vert_indices,
"attributes": attributes,
"vert": vert_file,
"geom": geom_file,
@ -75,9 +78,9 @@ def get_shader_info(raw_data=None,
def is_valid_shader_info(shader_info):
raw_data = shader_info["raw_data"]
vert_data = shader_info["vert_data"]
return all([
raw_data is not None and len(raw_data) > 0,
vert_data is not None,
shader_info["vert"],
shader_info["frag"],
])