From 24b160f9f9c3fde3a6c0f735f75df259173e32c3 Mon Sep 17 00:00:00 2001 From: Grant Sanderson Date: Mon, 19 Aug 2024 08:05:32 -0500 Subject: [PATCH] Update VMobject shader wrapper Use a combined VBO Render with TRIANGLE_STRIP, and ignore every other --- manimlib/mobject/mobject.py | 2 - manimlib/mobject/types/vectorized_mobject.py | 125 ++------- manimlib/shader_wrapper.py | 253 ++++++++++++------ .../shaders/quadratic_bezier_fill/geom.glsl | 10 +- .../shaders/quadratic_bezier_fill/vert.glsl | 2 +- .../shaders/quadratic_bezier_stroke/geom.glsl | 10 + .../shaders/quadratic_bezier_stroke/vert.glsl | 2 + 7 files changed, 213 insertions(+), 191 deletions(-) diff --git a/manimlib/mobject/mobject.py b/manimlib/mobject/mobject.py index 6238c4de..fad07aae 100644 --- a/manimlib/mobject/mobject.py +++ b/manimlib/mobject/mobject.py @@ -2049,8 +2049,6 @@ class Mobject(object): def render(self, ctx: Context, camera_uniforms: dict): if self._data_has_changed: self.shader_wrappers = self.get_shader_wrapper_list(ctx) - for shader_wrapper in self.shader_wrappers: - shader_wrapper.load_data() self._data_has_changed = False for shader_wrapper in self.shader_wrappers: shader_wrapper.update_program_uniforms(camera_uniforms) diff --git a/manimlib/mobject/types/vectorized_mobject.py b/manimlib/mobject/types/vectorized_mobject.py index 2ad4edd5..3bbf3820 100644 --- a/manimlib/mobject/types/vectorized_mobject.py +++ b/manimlib/mobject/types/vectorized_mobject.py @@ -5,6 +5,7 @@ from functools import wraps import moderngl import numpy as np import operator as op +import itertools as it from manimlib.constants import GREY_A, GREY_C, GREY_E from manimlib.constants import BLACK @@ -46,7 +47,7 @@ from manimlib.utils.space_ops import rotation_between_vectors from manimlib.utils.space_ops import poly_line_length from manimlib.utils.space_ops import z_to_vector from manimlib.shader_wrapper import ShaderWrapper -from manimlib.shader_wrapper import FillShaderWrapper +from manimlib.shader_wrapper import VShaderWrapper from typing import TYPE_CHECKING from typing import Generic, TypeVar, Iterable @@ -74,8 +75,8 @@ class VMobject(Mobject): ('unit_normal', np.float32, (3,)), ('fill_border_width', np.float32, (1,)), ]) - fill_data_names = ['point', 'fill_rgba', 'base_point', 'unit_normal'] - stroke_data_names = ['point', 'stroke_rgba', 'stroke_width', 'joint_product'] + fill_data_names = ['point', 'fill_rgba', 'base_point', 'unit_normal'] # Delete these + stroke_data_names = ['point', 'stroke_rgba', 'stroke_width', 'joint_product'] # Delete these fill_render_primitive: int = moderngl.TRIANGLES stroke_render_primitive: int = moderngl.TRIANGLES @@ -1325,49 +1326,17 @@ class VMobject(Mobject): return self # For shaders - def init_shader_data(self, ctx: Context): - dtype = self.shader_dtype - fill_dtype, stroke_dtype = ( - np.dtype([ - (name, dtype[name].base, dtype[name].shape) - for name in names - ]) - for names in [self.fill_data_names, self.stroke_data_names] - ) - fill_data = np.zeros(0, dtype=fill_dtype) - stroke_data = np.zeros(0, dtype=stroke_dtype) - self.fill_shader_wrapper = FillShaderWrapper( - ctx=ctx, - vert_data=fill_data, - mobject_uniforms=self.uniforms, - shader_folder=self.fill_shader_folder, - render_primitive=self.fill_render_primitive, - ) - self.stroke_shader_wrapper = ShaderWrapper( - ctx=ctx, - vert_data=stroke_data, - mobject_uniforms=self.uniforms, - shader_folder=self.stroke_shader_folder, - render_primitive=self.stroke_render_primitive, - ) - self.back_stroke_shader_wrapper = self.stroke_shader_wrapper.copy() - self.shader_wrappers = [ - self.back_stroke_shader_wrapper, - self.fill_shader_wrapper, - self.stroke_shader_wrapper, - ] - for sw in self.shader_wrappers: - family = self.family_members_with_points() - rep = family[0] if family else self - for old, new in rep.shader_code_replacements.items(): - sw.replace_code(old, new) - def refresh_shader_wrapper_id(self) -> Self: - if not self._shaders_initialized: - return self - for wrapper in self.shader_wrappers: - wrapper.refresh_id() - return self + def init_shader_data(self, ctx: Context): + self.shader_indices = np.zeros(0) + self.shader_wrapper = VShaderWrapper( + ctx=ctx, + vert_data=self.data, + mobject_uniforms=self.uniforms, + ) + + def get_shader_vert_indices(self): + return self.get_outer_vert_indices() def get_shader_wrapper_list(self, ctx: Context) -> list[ShaderWrapper]: if not self._shaders_initialized: @@ -1377,62 +1346,20 @@ class VMobject(Mobject): family = self.family_members_with_points() if not family: return [] - fill_names = self.fill_data_names - stroke_names = self.stroke_data_names - fill_family = (sm for sm in family if sm._has_fill) - stroke_family = (sm for sm in family if sm._has_stroke) + for submob in family: + if submob._has_fill: + submob.data["base_point"] = submob.data["point"][0] - # Build up fill data lists - fill_datas = [] - fill_indices = [] - fill_border_datas = [] - for submob in fill_family: - indices = submob.get_outer_vert_indices() - if submob._use_winding_fill: - data = submob.data[fill_names] - data["base_point"][:] = data["point"][0] - fill_datas.append(data[indices]) - else: - fill_datas.append(submob.data[fill_names]) - fill_indices.append(submob.get_triangulation()) - - draw_border_width = op.and_( - submob.data['fill_border_width'][0] > 0, - (not submob._has_stroke) or submob.stroke_behind, - ) - if draw_border_width: - # Add fill border - submob.get_joint_products() - names = list(stroke_names) - names[names.index('stroke_rgba')] = 'fill_rgba' - names[names.index('stroke_width')] = 'fill_border_width' - border_stroke_data = submob.data[names].astype( - self.stroke_shader_wrapper.vert_data.dtype - ) - fill_border_datas.append(border_stroke_data[indices]) - - # Build up stroke data lists - stroke_datas = [] - back_stroke_datas = [] - for submob in stroke_family: - submob.get_joint_products() - indices = submob.get_outer_vert_indices() - if submob.stroke_behind: - back_stroke_datas.append(submob.data[stroke_names][indices]) - else: - stroke_datas.append(submob.data[stroke_names][indices]) - - shader_wrappers = [ - self.back_stroke_shader_wrapper.read_in([*back_stroke_datas, *fill_border_datas]), - self.fill_shader_wrapper.read_in(fill_datas, fill_indices or None), - self.stroke_shader_wrapper.read_in(stroke_datas), - ] - for sw in shader_wrappers: - rep = family[0] # Representative family member - sw.bind_to_mobject_uniforms(rep.get_uniforms()) - sw.depth_test = rep.depth_test - return [sw for sw in shader_wrappers if len(sw.vert_data) > 0] + self.shader_wrapper.read_in( + # [sm.data for sm in family], + list(it.chain(*([sm.data, sm.data[-1:]] for sm in family))) + # [sm.get_shader_vert_indices() for sm in family] + ) + rep = family[0] # Representative family member + self.shader_wrapper.bind_to_mobject_uniforms(rep.get_uniforms()) + self.shader_wrapper.depth_test = rep.depth_test + return [self.shader_wrapper] class VGroup(Group, VMobject, Generic[SubVmobjectType]): diff --git a/manimlib/shader_wrapper.py b/manimlib/shader_wrapper.py index 191369ce..5445c06c 100644 --- a/manimlib/shader_wrapper.py +++ b/manimlib/shader_wrapper.py @@ -36,7 +36,6 @@ class ShaderWrapper(object): self, ctx: moderngl.context.Context, vert_data: np.ndarray, - vert_indices: Optional[np.ndarray] = None, shader_folder: Optional[str] = None, mobject_uniforms: Optional[UniformDict] = None, # A dictionary mapping names of uniform variables texture_paths: Optional[dict[str, str]] = None, # A dictionary mapping names to filepaths for textures. @@ -45,7 +44,6 @@ class ShaderWrapper(object): ): self.ctx = ctx self.vert_data = vert_data - self.vert_indices = (vert_indices or np.zeros(0)).astype(int) self.vert_attributes = vert_data.dtype.names self.shader_folder = shader_folder self.depth_test = depth_test @@ -59,7 +57,7 @@ class ShaderWrapper(object): self.texture_names_to_ids = dict() if texture_paths is not None: self.init_textures(texture_paths) - self.init_vao() + self.init_vertex_objects() self.refresh_id() def init_program_code(self) -> None: @@ -81,6 +79,7 @@ class ShaderWrapper(object): return self.program = get_shader_program(self.ctx, **self.program_code) self.vert_format = moderngl.detect_format(self.program, self.vert_attributes) + self.programs = [self.program] def init_textures(self, texture_paths: dict[str, str]): self.texture_names_to_ids = { @@ -88,9 +87,8 @@ class ShaderWrapper(object): for name, path in texture_paths.items() } - def init_vao(self): + def init_vertex_objects(self): self.vbo = None - self.ibo = None self.vao = None def bind_to_mobject_uniforms(self, mobject_uniforms: UniformDict): @@ -98,8 +96,7 @@ class ShaderWrapper(object): def __eq__(self, shader_wrapper: ShaderWrapper): return all(( - np.all(self.vert_data == shader_wrapper.vert_data), - np.all(self.vert_indices == shader_wrapper.vert_indices), + # np.all(self.vert_data == shader_wrapper.vert_data), self.shader_folder == shader_wrapper.shader_folder, all( self.mobject_uniforms[key] == shader_wrapper.mobject_uniforms[key] @@ -113,8 +110,7 @@ class ShaderWrapper(object): result = copy.copy(self) result.ctx = self.ctx result.vert_data = self.vert_data.copy() - result.vert_indices = self.vert_indices.copy() - result.init_vao() + result.init_vertex_objects() return result def is_valid(self) -> bool: @@ -129,10 +125,7 @@ class ShaderWrapper(object): def create_id(self) -> str: # A unique id for a shader - program_id = hash("".join( - self.program_code[f"{name}_shader"] or "" - for name in ("vertex", "geometry", "fragment") - )) + program_id = hash("".join(map(str, self.program_code.values()))) return "|".join(map(str, [ program_id, self.mobject_uniforms, @@ -173,45 +166,52 @@ class ShaderWrapper(object): def combine_with(self, *shader_wrappers: ShaderWrapper) -> ShaderWrapper: if len(shader_wrappers) > 0: - data_list = [self.vert_data, *(sw.vert_data for sw in shader_wrappers)] - indices_list = [self.vert_indices, *(sw.vert_indices for sw in shader_wrappers)] - self.read_in(data_list, indices_list) + self.read_in([self.vert_data, (sw.vert_data for sw in shader_wrappers)]) + vbos = [ + vbo + for vbo in [self.vbo, *(sw.vbo for sw in shader_wrappers)] + if vbo is not None + ] + total_size = sum(vbo.size for vbo in vbos) + new_vbo = self.ctx.buffer(reserve=total_size) + offset = 0 + for vbo in vbos: + new_vbo.write(vbo.read(), offset=offset) + offset += vbo.size + self.vbo = new_vbo return self def read_in( self, - data_list: List[np.ndarray], - indices_list: List[np.ndarray] | None = None - ) -> ShaderWrapper: - # Assume all are of the same type - total_len = sum(len(data) for data in data_list) - self.vert_data = resize_array(self.vert_data, total_len) + data_list: Iterable[np.ndarray], + indices_list: Iterable[np.ndarray] | None = None + ): + if indices_list is not None: + data_list = [data[indices] for data, indices in zip(data_list, indices_list)] + total_len = sum(map(len, indices_list)) + else: + total_len = sum(map(len, data_list)) + if total_len == 0: - return self + if self.vbo is not None: + self.vbo.clear() + return - # Stack the data - np.concatenate(data_list, out=self.vert_data) + # If possible, read concatenated data into existing list + if len(self.vert_data) != total_len: + self.vert_data = np.concatenate(data_list) + else: + np.concatenate(data_list, out=self.vert_data) - if indices_list is None: - self.vert_indices = resize_array(self.vert_indices, 0) - return self - - total_verts = sum(len(vi) for vi in indices_list) - if total_verts == 0: - return self - - self.vert_indices = resize_array(self.vert_indices, total_verts) - - # Stack vert_indices, but adding the appropriate offset - # alogn the way - n_points = 0 - n_verts = 0 - for data, indices in zip(data_list, indices_list): - new_n_verts = n_verts + len(indices) - self.vert_indices[n_verts:new_n_verts] = indices + n_points - n_verts = new_n_verts - n_points += len(data) - return self + # Either create new vbo, or read data into it + total_size = self.vert_data.itemsize * total_len + if self.vbo is None: + self.vbo = self.ctx.buffer(self.vert_data) + elif self.vbo.size != total_size: + self.vbo.release() + self.vbo = self.ctx.buffer(self.vert_data) + else: + self.vbo.write(self.vert_data) # Related to data and rendering def pre_render(self): @@ -219,74 +219,125 @@ class ShaderWrapper(object): self.set_ctx_clip_plane(self.use_clip_plane()) def render(self): - assert self.vao is not None + if self.vao is None: + self.generate_vao() self.vao.render() def update_program_uniforms(self, camera_uniforms: UniformDict): - if self.program is None: - return - for uniforms in [self.mobject_uniforms, camera_uniforms, self.texture_names_to_ids]: - for name, value in uniforms.items(): - set_program_uniform(self.program, name, value) - - def get_vertex_buffer_object(self): - self.vbo = self.ctx.buffer(self.vert_data) - return self.vbo - - def get_index_buffer_object(self): - if len(self.vert_indices) > 0: - self.ibo = self.ctx.buffer(self.vert_indices.astype(np.uint32)) - return self.ibo - - def load_data(self): - if self.vao is None: - self.generate_vao() - elif self.vao.vertices != len(self.vert_data): - self.release() - self.generate_vao() - else: - self.vbo.write(self.vert_data) - if self.ibo is not None: - self.ibo.write(self.self.vert_indices.astype(np.uint32)) + for program in self.programs: + if program is None: + continue + for uniforms in [self.mobject_uniforms, camera_uniforms, self.texture_names_to_ids]: + for name, value in uniforms.items(): + set_program_uniform(program, name, value) def generate_vao(self): - # Data buffer - vbo = self.get_vertex_buffer_object() - ibo = self.get_index_buffer_object() + if self.vbo is None: + self.vbo = self.ctx.buffer(self.vert_data) # Vertex array object self.vao = self.ctx.vertex_array( program=self.program, - content=[(vbo, self.vert_format, *self.vert_attributes)], - index_buffer=ibo, + content=[(self.vbo, self.vert_format, *self.vert_attributes)], mode=self.render_primitive, ) - return self.vao def release(self): - for obj in (self.vbo, self.ibo, self.vao): + for obj in (self.vbo, self.vao): if obj is not None: obj.release() self.vbo = None - self.ibo = None self.vao = None -class FillShaderWrapper(ShaderWrapper): +class VShaderWrapper(ShaderWrapper): def __init__( self, ctx: moderngl.context.Context, - *args, - **kwargs + vert_data: np.ndarray, + shader_folder: Optional[str] = None, + mobject_uniforms: Optional[UniformDict] = None, # A dictionary mapping names of uniform variables + texture_paths: Optional[dict[str, str]] = None, # A dictionary mapping names to filepaths for textures. + depth_test: bool = False, + # render_primitive: int = moderngl.TRIANGLES, + render_primitive: int = moderngl.TRIANGLE_STRIP, ): - super().__init__(ctx, *args, **kwargs) + super().__init__( + ctx=ctx, + vert_data=vert_data, + shader_folder=shader_folder, + mobject_uniforms=mobject_uniforms, + texture_paths=texture_paths, + depth_test=depth_test, + render_primitive=render_primitive, + ) self.fill_canvas = get_fill_canvas(self.ctx) - def render(self): - winding = (len(self.vert_indices) == 0) - self.program['winding'].value = winding + def init_program_code(self) -> None: + self.program_code = { + f"{vtype}_{name}": get_shader_code_from_file( + os.path.join(f"quadratic_bezier_{vtype}", f"{name}.glsl") + ) + for vtype in ["stroke", "fill"] + for name in ["vert", "geom", "frag"] + } + + def init_program(self): + self.stroke_program = get_shader_program( + self.ctx, + vertex_shader=self.program_code["stroke_vert"], + geometry_shader=self.program_code["stroke_geom"], + fragment_shader=self.program_code["stroke_frag"], + ) + self.fill_program = get_shader_program( + self.ctx, + vertex_shader=self.program_code["fill_vert"], + geometry_shader=self.program_code["fill_geom"], + fragment_shader=self.program_code["fill_frag"], + ) + self.programs = [self.stroke_program, self.fill_program] + + # Full vert format looks like this (total of 4x23 = 92 bytes): + # point 3 + # stroke_rgba 4 + # stroke_width 1 + # joint_product 4 + # fill_rgba 4 + # base_point 3 + # unit_normal 3 + # fill_border_width 1 + self.stroke_vert_format = '3f 4f 1f 4f 44x' + self.stroke_vert_attributes = ['point', 'stroke_rgba', 'stroke_width', 'joint_product'] + + self.fill_vert_format = '3f 36x 4f 3f 3f 4x' + self.fill_vert_attributes = ['point', 'fill_rgba', 'base_point', 'unit_normal'] + + def init_vertex_objects(self): + self.vbo = None + self.stroke_vao = None + self.fill_vao = None + + # TODO, think about create_id, replace_code + def is_valid(self) -> bool: + return self.vert_data is not None + + # TODO, motidify read in to handle triangulation case for non-winding fill? + + # Rendering + def render_stroke(self): + if self.stroke_vao is None: + return + self.stroke_vao.render() + + def render_fill(self): + if self.fill_vao is None: + return + + # TODO, need a new test here + winding = True + self.fill_program['winding'].value = winding if not winding: - super().render() + self.fill_vao.render() return original_fbo = self.ctx.fbo @@ -302,7 +353,7 @@ class FillShaderWrapper(ShaderWrapper): gl.GL_ONE_MINUS_DST_ALPHA, gl.GL_ONE, ) - super().render() + self.fill_vao.render() original_fbo.use() gl.glBlendFunc(gl.GL_ONE, gl.GL_ONE_MINUS_SRC_ALPHA) @@ -310,3 +361,29 @@ class FillShaderWrapper(ShaderWrapper): texture_vao.render() gl.glBlendFunc(gl.GL_SRC_ALPHA, gl.GL_ONE_MINUS_SRC_ALPHA) + + def render(self): + if self.stroke_vao is None or self.fill_vao is None: + self.generate_vao() + self.render_fill() + self.render_stroke() + + def generate_vao(self): + self.stroke_vao = self.ctx.vertex_array( + program=self.stroke_program, + content=[(self.vbo, self.stroke_vert_format, *self.stroke_vert_attributes)], + mode=self.render_primitive, + ) + self.fill_vao = self.ctx.vertex_array( + program=self.fill_program, + content=[(self.vbo, self.fill_vert_format, *self.fill_vert_attributes)], + mode=self.render_primitive, + ) + + def release(self): + attrs = ["vbo", "stroke_vao", "fill_vao"] + for attr in attrs: + obj = getattr(self, attr) + if obj is not None: + obj.release() + setattr(self, attr, None) diff --git a/manimlib/shaders/quadratic_bezier_fill/geom.glsl b/manimlib/shaders/quadratic_bezier_fill/geom.glsl index 99e10049..a36ec732 100644 --- a/manimlib/shaders/quadratic_bezier_fill/geom.glsl +++ b/manimlib/shaders/quadratic_bezier_fill/geom.glsl @@ -8,8 +8,8 @@ uniform bool winding; in vec3 verts[3]; in vec4 v_color[3]; in vec3 v_base_point[3]; -in float v_vert_index[3]; in vec3 v_unit_normal[3]; +in int v_vert_index[3]; out vec4 color; out float fill_all; @@ -57,10 +57,18 @@ void emit_simple_triangle(){ void main(){ + // Vector graphic shaders use TRIANGLE_STRIP, but only + // every other one needs to be rendered + if (v_vert_index[0] % 2 != 0) return; + // Curves are marked as ended when the handle after // the first anchor is set equal to that anchor if (verts[0] == verts[1]) return; + // Check zero fill + if (vec3(v_color[0].a, v_color[1].a, v_color[2].a) == vec3(0.0, 0.0, 0.0)) return; + + if(winding){ // Emit main triangle fill_all = 1.0; diff --git a/manimlib/shaders/quadratic_bezier_fill/vert.glsl b/manimlib/shaders/quadratic_bezier_fill/vert.glsl index a15752c4..a01cf22e 100644 --- a/manimlib/shaders/quadratic_bezier_fill/vert.glsl +++ b/manimlib/shaders/quadratic_bezier_fill/vert.glsl @@ -9,7 +9,7 @@ out vec3 verts; // Bezier control point out vec4 v_color; out vec3 v_base_point; out vec3 v_unit_normal; -out float v_vert_index; +out int v_vert_index; void main(){ verts = point; diff --git a/manimlib/shaders/quadratic_bezier_stroke/geom.glsl b/manimlib/shaders/quadratic_bezier_stroke/geom.glsl index ede46580..a34db252 100644 --- a/manimlib/shaders/quadratic_bezier_stroke/geom.glsl +++ b/manimlib/shaders/quadratic_bezier_stroke/geom.glsl @@ -14,6 +14,7 @@ in vec3 verts[3]; in vec4 v_joint_product[3]; in float v_stroke_width[3]; in vec4 v_color[3]; +in int v_vert_index[3]; out vec4 color; out float dist_to_curve; @@ -187,10 +188,19 @@ void emit_point_with_width( void main() { + // Vector graphic shaders use TRIANGLE_STRIP, but only + // every other one needs to be rendered + if (v_vert_index[0] % 2 != 0) return; + // Curves are marked as ended when the handle after // the first anchor is set equal to that anchor if (verts[0] == verts[1]) return; + // Check null stroke + if (vec3(v_stroke_width[0], v_stroke_width[1], v_stroke_width[2]) == vec3(0.0, 0.0, 0.0)) return; + if (vec3(v_color[0].a, v_color[1].a, v_color[2].a) == vec3(0.0, 0.0, 0.0)) return; + + // Coefficients such that the quadratic bezier is c0 + c1 * t + c2 * t^2 vec3 c0 = verts[0]; vec3 c1 = 2 * (verts[1] - verts[0]); diff --git a/manimlib/shaders/quadratic_bezier_stroke/vert.glsl b/manimlib/shaders/quadratic_bezier_stroke/vert.glsl index bca72d66..1095bf7a 100644 --- a/manimlib/shaders/quadratic_bezier_stroke/vert.glsl +++ b/manimlib/shaders/quadratic_bezier_stroke/vert.glsl @@ -14,6 +14,7 @@ out vec3 verts; out vec4 v_joint_product; out float v_stroke_width; out vec4 v_color; +out int v_vert_index; const float STROKE_WIDTH_CONVERSION = 0.01; @@ -22,4 +23,5 @@ void main(){ v_stroke_width = STROKE_WIDTH_CONVERSION * stroke_width * mix(frame_scale, 1, is_fixed_in_frame); v_joint_product = joint_product; v_color = stroke_rgba; + v_vert_index = gl_VertexID; } \ No newline at end of file