Update VMobject shader wrapper

Use a combined VBO
Render with TRIANGLE_STRIP, and ignore every other
This commit is contained in:
Grant Sanderson 2024-08-19 08:05:32 -05:00
parent f9b9cf69fd
commit 24b160f9f9
7 changed files with 213 additions and 191 deletions

View file

@ -2049,8 +2049,6 @@ class Mobject(object):
def render(self, ctx: Context, camera_uniforms: dict):
if self._data_has_changed:
self.shader_wrappers = self.get_shader_wrapper_list(ctx)
for shader_wrapper in self.shader_wrappers:
shader_wrapper.load_data()
self._data_has_changed = False
for shader_wrapper in self.shader_wrappers:
shader_wrapper.update_program_uniforms(camera_uniforms)

View file

@ -5,6 +5,7 @@ from functools import wraps
import moderngl
import numpy as np
import operator as op
import itertools as it
from manimlib.constants import GREY_A, GREY_C, GREY_E
from manimlib.constants import BLACK
@ -46,7 +47,7 @@ from manimlib.utils.space_ops import rotation_between_vectors
from manimlib.utils.space_ops import poly_line_length
from manimlib.utils.space_ops import z_to_vector
from manimlib.shader_wrapper import ShaderWrapper
from manimlib.shader_wrapper import FillShaderWrapper
from manimlib.shader_wrapper import VShaderWrapper
from typing import TYPE_CHECKING
from typing import Generic, TypeVar, Iterable
@ -74,8 +75,8 @@ class VMobject(Mobject):
('unit_normal', np.float32, (3,)),
('fill_border_width', np.float32, (1,)),
])
fill_data_names = ['point', 'fill_rgba', 'base_point', 'unit_normal']
stroke_data_names = ['point', 'stroke_rgba', 'stroke_width', 'joint_product']
fill_data_names = ['point', 'fill_rgba', 'base_point', 'unit_normal'] # Delete these
stroke_data_names = ['point', 'stroke_rgba', 'stroke_width', 'joint_product'] # Delete these
fill_render_primitive: int = moderngl.TRIANGLES
stroke_render_primitive: int = moderngl.TRIANGLES
@ -1325,49 +1326,17 @@ class VMobject(Mobject):
return self
# For shaders
def init_shader_data(self, ctx: Context):
dtype = self.shader_dtype
fill_dtype, stroke_dtype = (
np.dtype([
(name, dtype[name].base, dtype[name].shape)
for name in names
])
for names in [self.fill_data_names, self.stroke_data_names]
)
fill_data = np.zeros(0, dtype=fill_dtype)
stroke_data = np.zeros(0, dtype=stroke_dtype)
self.fill_shader_wrapper = FillShaderWrapper(
ctx=ctx,
vert_data=fill_data,
mobject_uniforms=self.uniforms,
shader_folder=self.fill_shader_folder,
render_primitive=self.fill_render_primitive,
)
self.stroke_shader_wrapper = ShaderWrapper(
ctx=ctx,
vert_data=stroke_data,
mobject_uniforms=self.uniforms,
shader_folder=self.stroke_shader_folder,
render_primitive=self.stroke_render_primitive,
)
self.back_stroke_shader_wrapper = self.stroke_shader_wrapper.copy()
self.shader_wrappers = [
self.back_stroke_shader_wrapper,
self.fill_shader_wrapper,
self.stroke_shader_wrapper,
]
for sw in self.shader_wrappers:
family = self.family_members_with_points()
rep = family[0] if family else self
for old, new in rep.shader_code_replacements.items():
sw.replace_code(old, new)
def refresh_shader_wrapper_id(self) -> Self:
if not self._shaders_initialized:
return self
for wrapper in self.shader_wrappers:
wrapper.refresh_id()
return self
def init_shader_data(self, ctx: Context):
self.shader_indices = np.zeros(0)
self.shader_wrapper = VShaderWrapper(
ctx=ctx,
vert_data=self.data,
mobject_uniforms=self.uniforms,
)
def get_shader_vert_indices(self):
return self.get_outer_vert_indices()
def get_shader_wrapper_list(self, ctx: Context) -> list[ShaderWrapper]:
if not self._shaders_initialized:
@ -1377,62 +1346,20 @@ class VMobject(Mobject):
family = self.family_members_with_points()
if not family:
return []
fill_names = self.fill_data_names
stroke_names = self.stroke_data_names
fill_family = (sm for sm in family if sm._has_fill)
stroke_family = (sm for sm in family if sm._has_stroke)
for submob in family:
if submob._has_fill:
submob.data["base_point"] = submob.data["point"][0]
# Build up fill data lists
fill_datas = []
fill_indices = []
fill_border_datas = []
for submob in fill_family:
indices = submob.get_outer_vert_indices()
if submob._use_winding_fill:
data = submob.data[fill_names]
data["base_point"][:] = data["point"][0]
fill_datas.append(data[indices])
else:
fill_datas.append(submob.data[fill_names])
fill_indices.append(submob.get_triangulation())
draw_border_width = op.and_(
submob.data['fill_border_width'][0] > 0,
(not submob._has_stroke) or submob.stroke_behind,
)
if draw_border_width:
# Add fill border
submob.get_joint_products()
names = list(stroke_names)
names[names.index('stroke_rgba')] = 'fill_rgba'
names[names.index('stroke_width')] = 'fill_border_width'
border_stroke_data = submob.data[names].astype(
self.stroke_shader_wrapper.vert_data.dtype
)
fill_border_datas.append(border_stroke_data[indices])
# Build up stroke data lists
stroke_datas = []
back_stroke_datas = []
for submob in stroke_family:
submob.get_joint_products()
indices = submob.get_outer_vert_indices()
if submob.stroke_behind:
back_stroke_datas.append(submob.data[stroke_names][indices])
else:
stroke_datas.append(submob.data[stroke_names][indices])
shader_wrappers = [
self.back_stroke_shader_wrapper.read_in([*back_stroke_datas, *fill_border_datas]),
self.fill_shader_wrapper.read_in(fill_datas, fill_indices or None),
self.stroke_shader_wrapper.read_in(stroke_datas),
]
for sw in shader_wrappers:
rep = family[0] # Representative family member
sw.bind_to_mobject_uniforms(rep.get_uniforms())
sw.depth_test = rep.depth_test
return [sw for sw in shader_wrappers if len(sw.vert_data) > 0]
self.shader_wrapper.read_in(
# [sm.data for sm in family],
list(it.chain(*([sm.data, sm.data[-1:]] for sm in family)))
# [sm.get_shader_vert_indices() for sm in family]
)
rep = family[0] # Representative family member
self.shader_wrapper.bind_to_mobject_uniforms(rep.get_uniforms())
self.shader_wrapper.depth_test = rep.depth_test
return [self.shader_wrapper]
class VGroup(Group, VMobject, Generic[SubVmobjectType]):

View file

@ -36,7 +36,6 @@ class ShaderWrapper(object):
self,
ctx: moderngl.context.Context,
vert_data: np.ndarray,
vert_indices: Optional[np.ndarray] = None,
shader_folder: Optional[str] = None,
mobject_uniforms: Optional[UniformDict] = None, # A dictionary mapping names of uniform variables
texture_paths: Optional[dict[str, str]] = None, # A dictionary mapping names to filepaths for textures.
@ -45,7 +44,6 @@ class ShaderWrapper(object):
):
self.ctx = ctx
self.vert_data = vert_data
self.vert_indices = (vert_indices or np.zeros(0)).astype(int)
self.vert_attributes = vert_data.dtype.names
self.shader_folder = shader_folder
self.depth_test = depth_test
@ -59,7 +57,7 @@ class ShaderWrapper(object):
self.texture_names_to_ids = dict()
if texture_paths is not None:
self.init_textures(texture_paths)
self.init_vao()
self.init_vertex_objects()
self.refresh_id()
def init_program_code(self) -> None:
@ -81,6 +79,7 @@ class ShaderWrapper(object):
return
self.program = get_shader_program(self.ctx, **self.program_code)
self.vert_format = moderngl.detect_format(self.program, self.vert_attributes)
self.programs = [self.program]
def init_textures(self, texture_paths: dict[str, str]):
self.texture_names_to_ids = {
@ -88,9 +87,8 @@ class ShaderWrapper(object):
for name, path in texture_paths.items()
}
def init_vao(self):
def init_vertex_objects(self):
self.vbo = None
self.ibo = None
self.vao = None
def bind_to_mobject_uniforms(self, mobject_uniforms: UniformDict):
@ -98,8 +96,7 @@ class ShaderWrapper(object):
def __eq__(self, shader_wrapper: ShaderWrapper):
return all((
np.all(self.vert_data == shader_wrapper.vert_data),
np.all(self.vert_indices == shader_wrapper.vert_indices),
# np.all(self.vert_data == shader_wrapper.vert_data),
self.shader_folder == shader_wrapper.shader_folder,
all(
self.mobject_uniforms[key] == shader_wrapper.mobject_uniforms[key]
@ -113,8 +110,7 @@ class ShaderWrapper(object):
result = copy.copy(self)
result.ctx = self.ctx
result.vert_data = self.vert_data.copy()
result.vert_indices = self.vert_indices.copy()
result.init_vao()
result.init_vertex_objects()
return result
def is_valid(self) -> bool:
@ -129,10 +125,7 @@ class ShaderWrapper(object):
def create_id(self) -> str:
# A unique id for a shader
program_id = hash("".join(
self.program_code[f"{name}_shader"] or ""
for name in ("vertex", "geometry", "fragment")
))
program_id = hash("".join(map(str, self.program_code.values())))
return "|".join(map(str, [
program_id,
self.mobject_uniforms,
@ -173,45 +166,52 @@ class ShaderWrapper(object):
def combine_with(self, *shader_wrappers: ShaderWrapper) -> ShaderWrapper:
if len(shader_wrappers) > 0:
data_list = [self.vert_data, *(sw.vert_data for sw in shader_wrappers)]
indices_list = [self.vert_indices, *(sw.vert_indices for sw in shader_wrappers)]
self.read_in(data_list, indices_list)
self.read_in([self.vert_data, (sw.vert_data for sw in shader_wrappers)])
vbos = [
vbo
for vbo in [self.vbo, *(sw.vbo for sw in shader_wrappers)]
if vbo is not None
]
total_size = sum(vbo.size for vbo in vbos)
new_vbo = self.ctx.buffer(reserve=total_size)
offset = 0
for vbo in vbos:
new_vbo.write(vbo.read(), offset=offset)
offset += vbo.size
self.vbo = new_vbo
return self
def read_in(
self,
data_list: List[np.ndarray],
indices_list: List[np.ndarray] | None = None
) -> ShaderWrapper:
# Assume all are of the same type
total_len = sum(len(data) for data in data_list)
self.vert_data = resize_array(self.vert_data, total_len)
data_list: Iterable[np.ndarray],
indices_list: Iterable[np.ndarray] | None = None
):
if indices_list is not None:
data_list = [data[indices] for data, indices in zip(data_list, indices_list)]
total_len = sum(map(len, indices_list))
else:
total_len = sum(map(len, data_list))
if total_len == 0:
return self
if self.vbo is not None:
self.vbo.clear()
return
# Stack the data
np.concatenate(data_list, out=self.vert_data)
# If possible, read concatenated data into existing list
if len(self.vert_data) != total_len:
self.vert_data = np.concatenate(data_list)
else:
np.concatenate(data_list, out=self.vert_data)
if indices_list is None:
self.vert_indices = resize_array(self.vert_indices, 0)
return self
total_verts = sum(len(vi) for vi in indices_list)
if total_verts == 0:
return self
self.vert_indices = resize_array(self.vert_indices, total_verts)
# Stack vert_indices, but adding the appropriate offset
# alogn the way
n_points = 0
n_verts = 0
for data, indices in zip(data_list, indices_list):
new_n_verts = n_verts + len(indices)
self.vert_indices[n_verts:new_n_verts] = indices + n_points
n_verts = new_n_verts
n_points += len(data)
return self
# Either create new vbo, or read data into it
total_size = self.vert_data.itemsize * total_len
if self.vbo is None:
self.vbo = self.ctx.buffer(self.vert_data)
elif self.vbo.size != total_size:
self.vbo.release()
self.vbo = self.ctx.buffer(self.vert_data)
else:
self.vbo.write(self.vert_data)
# Related to data and rendering
def pre_render(self):
@ -219,74 +219,125 @@ class ShaderWrapper(object):
self.set_ctx_clip_plane(self.use_clip_plane())
def render(self):
assert self.vao is not None
if self.vao is None:
self.generate_vao()
self.vao.render()
def update_program_uniforms(self, camera_uniforms: UniformDict):
if self.program is None:
return
for uniforms in [self.mobject_uniforms, camera_uniforms, self.texture_names_to_ids]:
for name, value in uniforms.items():
set_program_uniform(self.program, name, value)
def get_vertex_buffer_object(self):
self.vbo = self.ctx.buffer(self.vert_data)
return self.vbo
def get_index_buffer_object(self):
if len(self.vert_indices) > 0:
self.ibo = self.ctx.buffer(self.vert_indices.astype(np.uint32))
return self.ibo
def load_data(self):
if self.vao is None:
self.generate_vao()
elif self.vao.vertices != len(self.vert_data):
self.release()
self.generate_vao()
else:
self.vbo.write(self.vert_data)
if self.ibo is not None:
self.ibo.write(self.self.vert_indices.astype(np.uint32))
for program in self.programs:
if program is None:
continue
for uniforms in [self.mobject_uniforms, camera_uniforms, self.texture_names_to_ids]:
for name, value in uniforms.items():
set_program_uniform(program, name, value)
def generate_vao(self):
# Data buffer
vbo = self.get_vertex_buffer_object()
ibo = self.get_index_buffer_object()
if self.vbo is None:
self.vbo = self.ctx.buffer(self.vert_data)
# Vertex array object
self.vao = self.ctx.vertex_array(
program=self.program,
content=[(vbo, self.vert_format, *self.vert_attributes)],
index_buffer=ibo,
content=[(self.vbo, self.vert_format, *self.vert_attributes)],
mode=self.render_primitive,
)
return self.vao
def release(self):
for obj in (self.vbo, self.ibo, self.vao):
for obj in (self.vbo, self.vao):
if obj is not None:
obj.release()
self.vbo = None
self.ibo = None
self.vao = None
class FillShaderWrapper(ShaderWrapper):
class VShaderWrapper(ShaderWrapper):
def __init__(
self,
ctx: moderngl.context.Context,
*args,
**kwargs
vert_data: np.ndarray,
shader_folder: Optional[str] = None,
mobject_uniforms: Optional[UniformDict] = None, # A dictionary mapping names of uniform variables
texture_paths: Optional[dict[str, str]] = None, # A dictionary mapping names to filepaths for textures.
depth_test: bool = False,
# render_primitive: int = moderngl.TRIANGLES,
render_primitive: int = moderngl.TRIANGLE_STRIP,
):
super().__init__(ctx, *args, **kwargs)
super().__init__(
ctx=ctx,
vert_data=vert_data,
shader_folder=shader_folder,
mobject_uniforms=mobject_uniforms,
texture_paths=texture_paths,
depth_test=depth_test,
render_primitive=render_primitive,
)
self.fill_canvas = get_fill_canvas(self.ctx)
def render(self):
winding = (len(self.vert_indices) == 0)
self.program['winding'].value = winding
def init_program_code(self) -> None:
self.program_code = {
f"{vtype}_{name}": get_shader_code_from_file(
os.path.join(f"quadratic_bezier_{vtype}", f"{name}.glsl")
)
for vtype in ["stroke", "fill"]
for name in ["vert", "geom", "frag"]
}
def init_program(self):
self.stroke_program = get_shader_program(
self.ctx,
vertex_shader=self.program_code["stroke_vert"],
geometry_shader=self.program_code["stroke_geom"],
fragment_shader=self.program_code["stroke_frag"],
)
self.fill_program = get_shader_program(
self.ctx,
vertex_shader=self.program_code["fill_vert"],
geometry_shader=self.program_code["fill_geom"],
fragment_shader=self.program_code["fill_frag"],
)
self.programs = [self.stroke_program, self.fill_program]
# Full vert format looks like this (total of 4x23 = 92 bytes):
# point 3
# stroke_rgba 4
# stroke_width 1
# joint_product 4
# fill_rgba 4
# base_point 3
# unit_normal 3
# fill_border_width 1
self.stroke_vert_format = '3f 4f 1f 4f 44x'
self.stroke_vert_attributes = ['point', 'stroke_rgba', 'stroke_width', 'joint_product']
self.fill_vert_format = '3f 36x 4f 3f 3f 4x'
self.fill_vert_attributes = ['point', 'fill_rgba', 'base_point', 'unit_normal']
def init_vertex_objects(self):
self.vbo = None
self.stroke_vao = None
self.fill_vao = None
# TODO, think about create_id, replace_code
def is_valid(self) -> bool:
return self.vert_data is not None
# TODO, motidify read in to handle triangulation case for non-winding fill?
# Rendering
def render_stroke(self):
if self.stroke_vao is None:
return
self.stroke_vao.render()
def render_fill(self):
if self.fill_vao is None:
return
# TODO, need a new test here
winding = True
self.fill_program['winding'].value = winding
if not winding:
super().render()
self.fill_vao.render()
return
original_fbo = self.ctx.fbo
@ -302,7 +353,7 @@ class FillShaderWrapper(ShaderWrapper):
gl.GL_ONE_MINUS_DST_ALPHA, gl.GL_ONE,
)
super().render()
self.fill_vao.render()
original_fbo.use()
gl.glBlendFunc(gl.GL_ONE, gl.GL_ONE_MINUS_SRC_ALPHA)
@ -310,3 +361,29 @@ class FillShaderWrapper(ShaderWrapper):
texture_vao.render()
gl.glBlendFunc(gl.GL_SRC_ALPHA, gl.GL_ONE_MINUS_SRC_ALPHA)
def render(self):
if self.stroke_vao is None or self.fill_vao is None:
self.generate_vao()
self.render_fill()
self.render_stroke()
def generate_vao(self):
self.stroke_vao = self.ctx.vertex_array(
program=self.stroke_program,
content=[(self.vbo, self.stroke_vert_format, *self.stroke_vert_attributes)],
mode=self.render_primitive,
)
self.fill_vao = self.ctx.vertex_array(
program=self.fill_program,
content=[(self.vbo, self.fill_vert_format, *self.fill_vert_attributes)],
mode=self.render_primitive,
)
def release(self):
attrs = ["vbo", "stroke_vao", "fill_vao"]
for attr in attrs:
obj = getattr(self, attr)
if obj is not None:
obj.release()
setattr(self, attr, None)

View file

@ -8,8 +8,8 @@ uniform bool winding;
in vec3 verts[3];
in vec4 v_color[3];
in vec3 v_base_point[3];
in float v_vert_index[3];
in vec3 v_unit_normal[3];
in int v_vert_index[3];
out vec4 color;
out float fill_all;
@ -57,10 +57,18 @@ void emit_simple_triangle(){
void main(){
// Vector graphic shaders use TRIANGLE_STRIP, but only
// every other one needs to be rendered
if (v_vert_index[0] % 2 != 0) return;
// Curves are marked as ended when the handle after
// the first anchor is set equal to that anchor
if (verts[0] == verts[1]) return;
// Check zero fill
if (vec3(v_color[0].a, v_color[1].a, v_color[2].a) == vec3(0.0, 0.0, 0.0)) return;
if(winding){
// Emit main triangle
fill_all = 1.0;

View file

@ -9,7 +9,7 @@ out vec3 verts; // Bezier control point
out vec4 v_color;
out vec3 v_base_point;
out vec3 v_unit_normal;
out float v_vert_index;
out int v_vert_index;
void main(){
verts = point;

View file

@ -14,6 +14,7 @@ in vec3 verts[3];
in vec4 v_joint_product[3];
in float v_stroke_width[3];
in vec4 v_color[3];
in int v_vert_index[3];
out vec4 color;
out float dist_to_curve;
@ -187,10 +188,19 @@ void emit_point_with_width(
void main() {
// Vector graphic shaders use TRIANGLE_STRIP, but only
// every other one needs to be rendered
if (v_vert_index[0] % 2 != 0) return;
// Curves are marked as ended when the handle after
// the first anchor is set equal to that anchor
if (verts[0] == verts[1]) return;
// Check null stroke
if (vec3(v_stroke_width[0], v_stroke_width[1], v_stroke_width[2]) == vec3(0.0, 0.0, 0.0)) return;
if (vec3(v_color[0].a, v_color[1].a, v_color[2].a) == vec3(0.0, 0.0, 0.0)) return;
// Coefficients such that the quadratic bezier is c0 + c1 * t + c2 * t^2
vec3 c0 = verts[0];
vec3 c1 = 2 * (verts[1] - verts[0]);

View file

@ -14,6 +14,7 @@ out vec3 verts;
out vec4 v_joint_product;
out float v_stroke_width;
out vec4 v_color;
out int v_vert_index;
const float STROKE_WIDTH_CONVERSION = 0.01;
@ -22,4 +23,5 @@ void main(){
v_stroke_width = STROKE_WIDTH_CONVERSION * stroke_width * mix(frame_scale, 1, is_fixed_in_frame);
v_joint_product = joint_product;
v_color = stroke_rgba;
v_vert_index = gl_VertexID;
}