Unify shader_dtype and data_dtype

This commit is contained in:
Grant Sanderson 2023-01-15 19:09:29 -08:00
parent 3f2fd5b142
commit 6f9f83fb1b
9 changed files with 45 additions and 62 deletions

View file

@ -64,11 +64,7 @@ class Mobject(object):
shader_folder: str = "" shader_folder: str = ""
render_primitive: int = moderngl.TRIANGLE_STRIP render_primitive: int = moderngl.TRIANGLE_STRIP
# Must match in attributes of vert shader # Must match in attributes of vert shader
shader_dtype: Sequence[Tuple[str, type, Tuple[int]]] = [ shader_dtype: np.dtype = np.dtype([
('point', np.float32, (3,)),
('rgba', np.float32, (4,)),
]
data_dtype: np.dtype = np.dtype([
('point', np.float32, (3,)), ('point', np.float32, (3,)),
('rgba', np.float32, (4,)), ('rgba', np.float32, (4,)),
]) ])
@ -135,7 +131,7 @@ class Mobject(object):
return self.replicate(other) return self.replicate(other)
def init_data(self, length: int = 0): def init_data(self, length: int = 0):
self.data = np.zeros(length, dtype=self.data_dtype) self.data = np.zeros(length, dtype=self.shader_dtype)
def init_uniforms(self): def init_uniforms(self):
self.uniforms: dict[str, float | np.ndarray] = { self.uniforms: dict[str, float | np.ndarray] = {

View file

@ -25,15 +25,11 @@ class DotCloud(PMobject):
shader_folder: str = "true_dot" shader_folder: str = "true_dot"
render_primitive: int = moderngl.POINTS render_primitive: int = moderngl.POINTS
shader_dtype: Sequence[Tuple[str, type, Tuple[int]]] = [ shader_dtype: Sequence[Tuple[str, type, Tuple[int]]] = [
('point', np.float32, (3,)),
('radius', np.float32, (1,)),
('color', np.float32, (4,)),
]
data_dtype: np.dtype = np.dtype([
('point', np.float32, (3,)), ('point', np.float32, (3,)),
('radius', np.float32, (1,)), ('radius', np.float32, (1,)),
('rgba', np.float32, (4,)), ('rgba', np.float32, (4,)),
]) ]
def __init__( def __init__(
self, self,
points: Vect3Array = NULL_POINTS, points: Vect3Array = NULL_POINTS,
@ -150,7 +146,7 @@ class DotCloud(PMobject):
def get_shader_data(self) -> np.ndarray: def get_shader_data(self) -> np.ndarray:
shader_data = super().get_shader_data() shader_data = super().get_shader_data()
self.read_data_to_shader(shader_data, "radius", "radius") self.read_data_to_shader(shader_data, "radius", "radius")
self.read_data_to_shader(shader_data, "color", "rgba") self.read_data_to_shader(shader_data, "rgba", "rgba")
return shader_data return shader_data

View file

@ -24,11 +24,6 @@ class ImageMobject(Mobject):
('im_coords', np.float32, (2,)), ('im_coords', np.float32, (2,)),
('opacity', np.float32, (1,)), ('opacity', np.float32, (1,)),
] ]
data_dtype: Sequence[Tuple[str, type, Tuple[int]]] = [
('point', np.float32, (3,)),
('im_coords', np.float32, (2,)),
('opacity', np.float32, (1,)),
]
def __init__( def __init__(
self, self,

View file

@ -27,12 +27,6 @@ class Surface(Mobject):
render_primitive: int = moderngl.TRIANGLES render_primitive: int = moderngl.TRIANGLES
shader_folder: str = "surface" shader_folder: str = "surface"
shader_dtype: Sequence[Tuple[str, type, Tuple[int]]] = [ shader_dtype: Sequence[Tuple[str, type, Tuple[int]]] = [
('point', np.float32, (3,)),
('du_point', np.float32, (3,)),
('dv_point', np.float32, (3,)),
('color', np.float32, (4,)),
]
data_dtype: Sequence[Tuple[str, type, Tuple[int]]] = [
('point', np.float32, (3,)), ('point', np.float32, (3,)),
('du_point', np.float32, (3,)), ('du_point', np.float32, (3,)),
('dv_point', np.float32, (3,)), ('dv_point', np.float32, (3,)),
@ -249,7 +243,7 @@ class Surface(Mobject):
return shader_data return shader_data
def fill_in_shader_color_info(self, shader_data: np.ndarray) -> np.ndarray: def fill_in_shader_color_info(self, shader_data: np.ndarray) -> np.ndarray:
self.read_data_to_shader(shader_data, "color", "rgba") self.read_data_to_shader(shader_data, "rgba", "rgba")
return shader_data return shader_data
def get_shader_vert_indices(self) -> np.ndarray: def get_shader_vert_indices(self) -> np.ndarray:
@ -293,13 +287,6 @@ class TexturedSurface(Surface):
('im_coords', np.float32, (2,)), ('im_coords', np.float32, (2,)),
('opacity', np.float32, (1,)), ('opacity', np.float32, (1,)),
] ]
data_dtype: Sequence[Tuple[str, type, Tuple[int]]] = [
('point', np.float32, (3,)),
('du_point', np.float32, (3,)),
('dv_point', np.float32, (3,)),
('im_coords', np.float32, (2,)),
('opacity', np.float32, (1,)),
]
def __init__( def __init__(
self, self,

View file

@ -54,26 +54,28 @@ DEFAULT_FILL_COLOR = GREY_C
class VMobject(Mobject): class VMobject(Mobject):
fill_shader_folder: str = "quadratic_bezier_fill" fill_shader_folder: str = "quadratic_bezier_fill"
stroke_shader_folder: str = "quadratic_bezier_stroke" stroke_shader_folder: str = "quadratic_bezier_stroke"
fill_dtype: Sequence[Tuple[str, type, Tuple[int]]] = [ shader_dtype: np.dtype = np.dtype([
('point', np.float32, (3,)), ('point', np.float32, (3,)),
('orientation', np.float32, (1,)), ('stroke_rgba', np.float32, (4,)),
('color', np.float32, (4,)), ('stroke_width', np.float32, (1,)),
('vert_index', np.float32, (1,)), ('joint_angle', np.float32, (1,)),
]
stroke_dtype: Sequence[Tuple[str, type, Tuple[int]]] = [
("point", np.float32, (3,)),
("joint_angle", np.float32, (1,)),
("stroke_width", np.float32, (1,)),
("color", np.float32, (4,)),
]
data_dtype: Sequence[Tuple[str, type, Tuple[int]]] = [
("point", np.float32, (3,)),
('fill_rgba', np.float32, (4,)), ('fill_rgba', np.float32, (4,)),
("stroke_rgba", np.float32, (4,)),
("joint_angle", np.float32, (1,)),
("stroke_width", np.float32, (1,)),
('orientation', np.float32, (1,)), ('orientation', np.float32, (1,)),
] ('vert_index', np.float32, (1,)),
])
fill_dtype: np.dtype = np.dtype([
('point', np.float32, (3,)),
('fill_rgba', np.float32, (4,)),
('orientation', np.float32, (1,)),
('vert_index', np.float32, (1,)),
])
stroke_dtype: np.dtype = np.dtype([
('point', np.float32, (3,)),
('stroke_rgba', np.float32, (4,)),
('stroke_width', np.float32, (1,)),
('joint_angle', np.float32, (1,)),
])
fill_render_primitive: int = moderngl.TRIANGLES fill_render_primitive: int = moderngl.TRIANGLES
stroke_render_primitive: int = moderngl.TRIANGLE_STRIP stroke_render_primitive: int = moderngl.TRIANGLE_STRIP
@ -1146,6 +1148,14 @@ class VMobject(Mobject):
super().set_data(data) super().set_data(data)
return self return self
def resize_points(
self,
new_length: int,
resize_func: Callable[[np.ndarray, int], np.ndarray] = resize_array
):
super().resize_points(new_length, resize_func)
self.data["vert_index"][:, 0] = np.arange(new_length)
# TODO, how to be smart about tangents here? # TODO, how to be smart about tangents here?
@triggers_refreshed_triangulation @triggers_refreshed_triangulation
def apply_function( def apply_function(
@ -1247,7 +1257,7 @@ class VMobject(Mobject):
return self.stroke_data return self.stroke_data
self.read_data_to_shader(self.stroke_data[:n], "point", "point") self.read_data_to_shader(self.stroke_data[:n], "point", "point")
self.read_data_to_shader(self.stroke_data[:n], "color", "stroke_rgba") self.read_data_to_shader(self.stroke_data[:n], "stroke_rgba", "stroke_rgba")
self.read_data_to_shader(self.stroke_data[:n], "stroke_width", "stroke_width") self.read_data_to_shader(self.stroke_data[:n], "stroke_width", "stroke_width")
self.get_joint_angles() # Recomputes, only if refresh is needed self.get_joint_angles() # Recomputes, only if refresh is needed
self.read_data_to_shader(self.stroke_data[:n], "joint_angle", "joint_angle") self.read_data_to_shader(self.stroke_data[:n], "joint_angle", "joint_angle")
@ -1262,7 +1272,7 @@ class VMobject(Mobject):
self.fill_data["vert_index"][:, 0] = range(len(points)) self.fill_data["vert_index"][:, 0] = range(len(points))
self.read_data_to_shader(self.fill_data, "point", "point") self.read_data_to_shader(self.fill_data, "point", "point")
self.read_data_to_shader(self.fill_data, "color", "fill_rgba") self.read_data_to_shader(self.fill_data, "fill_rgba", "fill_rgba")
self.read_data_to_shader(self.fill_data, "orientation", "orientation") self.read_data_to_shader(self.fill_data, "orientation", "orientation")
return self.fill_data return self.fill_data

View file

@ -1,8 +1,8 @@
#version 330 #version 330
in vec3 point; in vec3 point;
in vec4 fill_rgba;
in float orientation; in float orientation;
in vec4 color;
in float vert_index; in float vert_index;
out vec3 verts; // Bezier control point out vec3 verts; // Bezier control point
@ -16,6 +16,6 @@ out float v_vert_index;
void main(){ void main(){
verts = position_point_into_frame(point); verts = position_point_into_frame(point);
v_orientation = orientation; v_orientation = orientation;
v_color = color; v_color = fill_rgba;
v_vert_index = vert_index; v_vert_index = vert_index;
} }

View file

@ -1,10 +1,9 @@
#version 330 #version 330
in vec3 point; in vec3 point;
in vec4 stroke_rgba;
in float joint_angle;
in float stroke_width; in float stroke_width;
in vec4 color; in float joint_angle;
// Bezier control point // Bezier control point
out vec3 verts; out vec3 verts;
@ -23,6 +22,6 @@ void main(){
v_stroke_width = STROKE_WIDTH_CONVERSION * stroke_width * frame_shape[1] / 8.0; v_stroke_width = STROKE_WIDTH_CONVERSION * stroke_width * frame_shape[1] / 8.0;
v_joint_angle = joint_angle; v_joint_angle = joint_angle;
v_color = color; v_color = stroke_rgba;
v_vert_index = gl_VertexID; v_vert_index = gl_VertexID;
} }

View file

@ -5,7 +5,7 @@ uniform vec4 clip_plane;
in vec3 point; in vec3 point;
in vec3 du_point; in vec3 du_point;
in vec3 dv_point; in vec3 dv_point;
in vec4 color; in vec4 rgba;
out vec3 xyz_coords; out vec3 xyz_coords;
out vec3 v_normal; out vec3 v_normal;
@ -18,7 +18,7 @@ out vec4 v_color;
void main(){ void main(){
xyz_coords = position_point_into_frame(point); xyz_coords = position_point_into_frame(point);
v_normal = get_rotated_surface_unit_normal_vector(point, du_point, dv_point); v_normal = get_rotated_surface_unit_normal_vector(point, du_point, dv_point);
v_color = color; v_color = rgba;
gl_Position = get_gl_Position(xyz_coords); gl_Position = get_gl_Position(xyz_coords);
if(clip_plane.xyz != vec3(0.0, 0.0, 0.0)){ if(clip_plane.xyz != vec3(0.0, 0.0, 0.0)){
@ -26,7 +26,7 @@ void main(){
} }
v_color = finalize_color( v_color = finalize_color(
color, rgba,
xyz_coords, xyz_coords,
v_normal, v_normal,
light_source_position, light_source_position,

View file

@ -2,7 +2,7 @@
in vec3 point; in vec3 point;
in float radius; in float radius;
in vec4 color; in vec4 rgba;
out vec3 v_point; out vec3 v_point;
out float v_radius; out float v_radius;
@ -13,5 +13,5 @@ out vec4 v_color;
void main(){ void main(){
v_point = position_point_into_frame(point); v_point = position_point_into_frame(point);
v_radius = radius; v_radius = radius;
v_color = color; v_color = rgba;
} }