mirror of
https://github.com/3b1b/manim.git
synced 2025-09-19 04:41:56 +00:00
Use an index buffer for shaders to save memory
This commit is contained in:
parent
978137b143
commit
2b3bd2bfce
5 changed files with 96 additions and 57 deletions
|
@ -331,9 +331,9 @@ class Camera(object):
|
|||
def render(self, shader_info):
|
||||
cached_buffers = "render_group" in shader_info
|
||||
if cached_buffers:
|
||||
vbo, vao, shader = shader_info["render_group"]
|
||||
vbo, ibo, vao, shader = shader_info["render_group"]
|
||||
else:
|
||||
vbo, vao, shader = self.get_render_group(shader_info)
|
||||
vbo, ibo, vao, shader = self.get_render_group(shader_info)
|
||||
|
||||
self.set_shader_uniforms(shader, shader_info)
|
||||
|
||||
|
@ -345,17 +345,25 @@ class Camera(object):
|
|||
vao.render(int(shader_info["render_primative"]))
|
||||
|
||||
if not cached_buffers:
|
||||
vbo.release()
|
||||
vao.release()
|
||||
self.release_gl_objects(vbo, ibo, vao)
|
||||
|
||||
def get_render_group(self, shader_info):
|
||||
shader, vert_format = self.get_shader(shader_info)
|
||||
vbo = self.ctx.buffer(shader_info["raw_data"])
|
||||
# vbo = self.ctx.buffer(shader_info["vert_data"].tobytes())
|
||||
vbo = self.ctx.buffer(shader_info["vert_data"])
|
||||
|
||||
vert_indices = shader_info["vert_indices"]
|
||||
if vert_indices is None:
|
||||
ibo = None
|
||||
else:
|
||||
ibo = self.ctx.buffer(vert_indices.astype('i4').tobytes())
|
||||
|
||||
vao = self.ctx.vertex_array(
|
||||
program=shader,
|
||||
content=[(vbo, vert_format, *shader_info["attributes"])]
|
||||
content=[(vbo, vert_format, *shader_info["attributes"])],
|
||||
index_buffer=ibo,
|
||||
)
|
||||
return (vbo, vao, shader)
|
||||
return (vbo, ibo, vao, shader)
|
||||
|
||||
def set_mobjects_as_static(self, *mobjects):
|
||||
# Create buffer and array objects holding each mobjects shader data
|
||||
|
@ -368,11 +376,14 @@ class Camera(object):
|
|||
def release_static_mobjects(self):
|
||||
for mob, info_list in self.static_mobjects_to_shader_info_list.items():
|
||||
for info in info_list:
|
||||
vbo, vao, shader = info["render_group"]
|
||||
vbo.release()
|
||||
vao.release()
|
||||
self.release_gl_objects(*info["render_group"][:3])
|
||||
self.static_mobjects_to_shader_info_list = {}
|
||||
|
||||
def release_gl_objects(self, *objs):
|
||||
for obj in objs:
|
||||
if obj:
|
||||
obj.release()
|
||||
|
||||
# Shaders
|
||||
def init_shaders(self):
|
||||
# Initialize with the null id going to None
|
||||
|
@ -387,7 +398,7 @@ class Camera(object):
|
|||
vert_format = moderngl.detect_format(program, shader_info["attributes"])
|
||||
self.id_to_shader[sid] = (program, vert_format)
|
||||
program, vert_format = self.id_to_shader[sid]
|
||||
self.set_shader_uniforms(program, shader_info)
|
||||
# self.set_shader_uniforms(program, shader_info)
|
||||
return program, vert_format
|
||||
|
||||
def set_shader_uniforms(self, shader, shader_info):
|
||||
|
|
|
@ -1218,7 +1218,9 @@ class Mobject(Container):
|
|||
# For shaders
|
||||
def init_shader_data(self):
|
||||
self.shader_data = np.zeros(len(self.points), dtype=self.shader_dtype)
|
||||
self.shader_indices = None
|
||||
self.shader_info_template = get_shader_info(
|
||||
attributes=self.shader_data.dtype.names,
|
||||
vert_file=self.vert_shader_file,
|
||||
geom_file=self.geom_shader_file,
|
||||
frag_file=self.frag_shader_file,
|
||||
|
@ -1250,17 +1252,29 @@ class Mobject(Container):
|
|||
|
||||
result = []
|
||||
for info_group, sid in batches:
|
||||
shader_info = info_group[0]
|
||||
shader_info["raw_data"] = b''.join([info["raw_data"] for info in info_group])
|
||||
if is_valid_shader_info(shader_info):
|
||||
result.append(shader_info)
|
||||
combined_info = info_group[0]
|
||||
if not is_valid_shader_info(combined_info):
|
||||
continue
|
||||
data_list = []
|
||||
indices_list = []
|
||||
num_verts = 0
|
||||
for info in info_group:
|
||||
data_list.append(info["vert_data"])
|
||||
if info["vert_indices"] is not None:
|
||||
indices_list.append(info["vert_indices"] + num_verts)
|
||||
num_verts += len(info["vert_data"])
|
||||
# Combine lists
|
||||
combined_info["vert_data"] = np.hstack(data_list)
|
||||
if combined_info["vert_indices"] is not None:
|
||||
combined_info["vert_indices"] = np.hstack(indices_list)
|
||||
if len(combined_info["vert_indices"]) > 0:
|
||||
result.append(combined_info)
|
||||
return result
|
||||
|
||||
def get_shader_info(self):
|
||||
shader_info = dict(self.shader_info_template)
|
||||
data = self.get_shader_data()
|
||||
shader_info["raw_data"] = data.tobytes()
|
||||
shader_info["attributes"] = data.dtype.names
|
||||
shader_info["vert_data"] = self.get_shader_data()
|
||||
shader_info["vert_indices"] = self.get_shader_vert_indices()
|
||||
shader_info["uniforms"] = self.get_shader_uniforms()
|
||||
return shader_info
|
||||
|
||||
|
@ -1276,6 +1290,9 @@ class Mobject(Container):
|
|||
# Must return a structured numpy array
|
||||
return self.shader_data
|
||||
|
||||
def get_shader_vert_indices(self):
|
||||
return self.shader_indices
|
||||
|
||||
# Errors
|
||||
def throw_error_if_no_points(self):
|
||||
if self.has_no_points():
|
||||
|
|
|
@ -70,8 +70,9 @@ class ParametricSurface(Mobject):
|
|||
# the resolution of the surface, make sure
|
||||
# this is called.
|
||||
nu, nv = self.resolution
|
||||
if nu == 0 and nv == 0:
|
||||
return np.zeros(0, dtype=int)
|
||||
if nu == 0 or nv == 0:
|
||||
self.triangle_indices = np.zeros(0, dtype=int)
|
||||
return
|
||||
index_grid = np.arange(nu * nv).reshape((nu, nv))
|
||||
indices = np.zeros(6 * (nu - 1) * (nv - 1), dtype=int)
|
||||
indices[0::6] = index_grid[:-1, :-1].flatten() # Top left
|
||||
|
@ -124,13 +125,10 @@ class ParametricSurface(Mobject):
|
|||
|
||||
def get_shader_data(self):
|
||||
s_points, du_points, dv_points = self.get_surface_points_and_nudged_points()
|
||||
tri_indices = self.get_triangle_indices()
|
||||
data = self.get_blank_shader_data_array(len(tri_indices))
|
||||
if len(tri_indices) == 0:
|
||||
return data
|
||||
data["point"] = s_points[tri_indices]
|
||||
data["du_point"] = du_points[tri_indices]
|
||||
data["dv_point"] = dv_points[tri_indices]
|
||||
data = self.get_blank_shader_data_array(len(s_points))
|
||||
data["point"] = s_points
|
||||
data["du_point"] = du_points
|
||||
data["dv_point"] = dv_points
|
||||
self.fill_in_shader_color_info(data)
|
||||
return data
|
||||
|
||||
|
@ -138,6 +136,9 @@ class ParametricSurface(Mobject):
|
|||
data["color"] = self.rgbas
|
||||
return data
|
||||
|
||||
def get_shader_vert_indices(self):
|
||||
return self.get_triangle_indices()
|
||||
|
||||
|
||||
class SGroup(ParametricSurface):
|
||||
CONFIG = {
|
||||
|
@ -145,15 +146,11 @@ class SGroup(ParametricSurface):
|
|||
}
|
||||
|
||||
def __init__(self, *parametric_surfaces, **kwargs):
|
||||
# TODO, separate out the surface type...again
|
||||
super().__init__(uv_func=None, **kwargs)
|
||||
self.add(*parametric_surfaces)
|
||||
|
||||
def init_points(self):
|
||||
pass
|
||||
|
||||
def get_triangle_indices(self):
|
||||
return np.zeros(0)
|
||||
self.points = np.zeros((0, 3))
|
||||
|
||||
|
||||
class TexturedSurface(ParametricSurface):
|
||||
|
|
|
@ -912,31 +912,39 @@ class VMobject(Mobject):
|
|||
for info in fill_info, stroke_info:
|
||||
info["depth_test"] = self.depth_test
|
||||
|
||||
# Build up data lists
|
||||
back_stroke_data = []
|
||||
stroke_data = []
|
||||
fill_data = []
|
||||
fill_vert_indices = []
|
||||
num_fill_verts = 0 # Number of fill verts
|
||||
for submob in self.family_members_with_points():
|
||||
if submob.has_fill():
|
||||
data = submob.get_fill_shader_data().tobytes()
|
||||
fill_data.append(data)
|
||||
if submob.has_stroke():
|
||||
if submob.draw_stroke_behind_fill:
|
||||
data = back_stroke_data
|
||||
else:
|
||||
data = stroke_data
|
||||
new_data = submob.get_stroke_shader_data()
|
||||
data.append(new_data.tobytes())
|
||||
data = submob.get_fill_shader_data()
|
||||
indices = submob.get_fill_shader_vert_indices() + num_fill_verts
|
||||
num_fill_verts += len(data)
|
||||
|
||||
fill_data.append(data)
|
||||
fill_vert_indices.append(indices)
|
||||
if submob.has_stroke():
|
||||
data = submob.get_stroke_shader_data()
|
||||
if submob.draw_stroke_behind_fill:
|
||||
back_stroke_data.append(data)
|
||||
else:
|
||||
stroke_data.append(data)
|
||||
|
||||
# Combine data lists
|
||||
result = []
|
||||
if back_stroke_data:
|
||||
back_stroke_info = dict(stroke_info) # Copy
|
||||
back_stroke_info["raw_data"] = b''.join(back_stroke_data)
|
||||
back_stroke_info["vert_data"] = np.hstack(back_stroke_data)
|
||||
result.append(back_stroke_info)
|
||||
if fill_data:
|
||||
fill_info["raw_data"] = b''.join(fill_data)
|
||||
fill_info["vert_data"] = np.hstack(fill_data)
|
||||
fill_info["vert_indices"] = np.hstack(fill_vert_indices)
|
||||
result.append(fill_info)
|
||||
if stroke_data:
|
||||
stroke_info["raw_data"] = b''.join(stroke_data)
|
||||
stroke_info["vert_data"] = np.hstack(stroke_data)
|
||||
result.append(stroke_info)
|
||||
return result
|
||||
|
||||
|
@ -1046,19 +1054,22 @@ class VMobject(Mobject):
|
|||
|
||||
def get_fill_shader_data(self):
|
||||
points = self.points
|
||||
n_points = len(points)
|
||||
unit_normal = self.get_unit_normal()
|
||||
tri_indices = self.get_triangulation(unit_normal)
|
||||
|
||||
# TODO, best way to enable multiple colors?
|
||||
rgbas = self.get_fill_rgbas()[:1]
|
||||
|
||||
data = self.get_blank_shader_data_array(len(tri_indices), "fill_data")
|
||||
data["point"] = points[tri_indices]
|
||||
data = self.get_blank_shader_data_array(n_points, "fill_data")
|
||||
data["point"] = points
|
||||
data["unit_normal"] = unit_normal
|
||||
data["color"] = rgbas
|
||||
data["vert_index"][:, 0] = tri_indices
|
||||
data["vert_index"][:, 0] = self.saved_triangulation[:n_points]
|
||||
return data
|
||||
|
||||
def get_fill_shader_vert_indices(self):
|
||||
return self.get_triangulation()
|
||||
|
||||
|
||||
class VGroup(VMobject):
|
||||
def __init__(self, *vmobjects, **kwargs):
|
||||
|
|
|
@ -22,9 +22,10 @@ from manimlib.constants import SHADER_DIR
|
|||
|
||||
|
||||
SHADER_INFO_KEYS = [
|
||||
# A structred array caring all of the points/color/lighting/etc. information
|
||||
# needed for the shader.
|
||||
"raw_data",
|
||||
# Vertex data for the shader (as structured array)
|
||||
"vert_data",
|
||||
# Index data (if applicable) for the shader
|
||||
"index_data",
|
||||
# List of variable names corresponding to inputs of vertex shader
|
||||
"attributes",
|
||||
# Filename of vetex shader
|
||||
|
@ -44,11 +45,12 @@ SHADER_INFO_KEYS = [
|
|||
"render_primative",
|
||||
]
|
||||
|
||||
# Exclude raw_data
|
||||
SHADER_KEYS_FOR_ID = SHADER_INFO_KEYS[2:]
|
||||
# Exclude data
|
||||
SHADER_KEYS_FOR_ID = SHADER_INFO_KEYS[3:]
|
||||
|
||||
|
||||
def get_shader_info(raw_data=None,
|
||||
def get_shader_info(vert_data=None,
|
||||
vert_indices=None,
|
||||
attributes=None,
|
||||
vert_file=None,
|
||||
geom_file=None,
|
||||
|
@ -59,7 +61,8 @@ def get_shader_info(raw_data=None,
|
|||
render_primative=moderngl.TRIANGLE_STRIP,
|
||||
):
|
||||
result = {
|
||||
"raw_data": raw_data,
|
||||
"vert_data": vert_data,
|
||||
"vert_indices": vert_indices,
|
||||
"attributes": attributes,
|
||||
"vert": vert_file,
|
||||
"geom": geom_file,
|
||||
|
@ -75,9 +78,9 @@ def get_shader_info(raw_data=None,
|
|||
|
||||
|
||||
def is_valid_shader_info(shader_info):
|
||||
raw_data = shader_info["raw_data"]
|
||||
vert_data = shader_info["vert_data"]
|
||||
return all([
|
||||
raw_data is not None and len(raw_data) > 0,
|
||||
vert_data is not None,
|
||||
shader_info["vert"],
|
||||
shader_info["frag"],
|
||||
])
|
||||
|
|
Loading…
Add table
Reference in a new issue