Add ability to manipulate a mobjects shader code from python, and in particular to give it a coloring function, e.g. with Mobject.set_color_by_xyz_func

This commit is contained in:
Grant Sanderson 2021-01-09 18:52:54 -08:00
parent a7af5e72c6
commit fdcc8d4257
4 changed files with 125 additions and 14 deletions

View file

@ -136,3 +136,17 @@ PALETTE = list(COLOR_MAP.values())
locals().update(COLOR_MAP)
for name in [s for s in list(COLOR_MAP.keys()) if s.endswith("_C")]:
locals()[name.replace("_C", "")] = locals()[name]
COLORMAPS = {
# From https://bids.github.io/colormap/
"viridis": [[0.267004, 0.004874, 0.329415],
[0.279574, 0.170599, 0.479997],
[0.231674, 0.318106, 0.544834],
[0.174274, 0.445044, 0.557792],
[0.128729, 0.563265, 0.551229],
[0.153894, 0.680203, 0.504172],
[0.360741, 0.785964, 0.387814],
[0.668054, 0.861999, 0.196293],
[0.993248, 0.906157, 0.143936]],
# TODO, add other options based on colormap
}

View file

@ -21,6 +21,7 @@ from manimlib.utils.space_ops import angle_of_vector
from manimlib.utils.space_ops import get_norm
from manimlib.utils.space_ops import rotation_matrix_transpose
from manimlib.shader_wrapper import ShaderWrapper
from manimlib.shader_wrapper import get_colormap_code
# TODO: Explain array_attrs
@ -1242,7 +1243,52 @@ class Mobject(object):
self.depth_test = False
return self
# For shaders
# Shader code manipulation
def replace_shader_code(self, old, new):
for wrapper in self.get_shader_wrapper_list():
wrapper.replace_code(old, new)
return self
def refresh_shader_code(self):
for wrapper in self.get_shader_wrapper_list():
wrapper.init_program_code()
wrapper.refresh_id()
return self
def set_color_by_code(self, glsl_code):
"""
Takes a snippet of code and inserts it into a
context which has the following variables:
vec4 color, vec3 point, vec3 unit_normal.
The code should change the color variable
"""
self.replace_shader_code(
"///// INSERT COLOR FUNCTION HERE /////",
glsl_code
)
return self
def set_color_by_xyz_func(self, glsl_snippet,
min_value=-5.0, max_value=5.0,
colormap="viridis"):
"""
Pass in a glsl expression in terms of x, y and z which returns
a float.
"""
for char in "xyz":
glsl_snippet = glsl_snippet.replace(char, "point." + char)
self.replace_shader_code(
"///// INSERT COLOR_MAP FUNCTION HERE /////",
get_colormap_code(colormap)
)
self.set_color_by_code(
"color.rgb = colormap({}, {}, {});".format(
glsl_snippet, float(min_value), float(max_value)
)
)
return self
# For shader data
def init_shader_data(self):
self.shader_data = np.zeros(len(self.points), dtype=self.shader_dtype)
self.shader_indices = None

View file

@ -4,6 +4,7 @@ import moderngl
import numpy as np
import copy
from manimlib.constants import COLORMAPS
from manimlib.utils.directories import get_shader_dir
from manimlib.utils.file_ops import find_file
@ -32,8 +33,8 @@ class ShaderWrapper(object):
self.texture_paths = texture_paths or dict()
self.depth_test = depth_test
self.render_primitive = str(render_primitive)
self.id = self.create_id()
self.program_id = self.create_program_id()
self.init_program_code()
self.refresh_id()
def copy(self):
result = copy.copy(self)
@ -49,7 +50,8 @@ class ShaderWrapper(object):
def is_valid(self):
return all([
self.vert_data is not None,
self.shader_folder,
self.program_code["vertex_shader"] is not None,
self.program_code["fragment_shader"] is not None,
])
def get_id(self):
@ -61,7 +63,7 @@ class ShaderWrapper(object):
def create_id(self):
# A unique id for a shader
return "|".join(map(str, [
self.shader_folder,
self.program_id,
self.uniforms,
self.texture_paths,
self.depth_test,
@ -69,23 +71,38 @@ class ShaderWrapper(object):
]))
def refresh_id(self):
self.program_id = self.create_program_id()
self.id = self.create_id()
def create_program_id(self):
return self.shader_folder
return hash("".join((
self.program_code[f"{name}_shader"] or ""
for name in ("vertex", "geometry", "fragment")
)))
def get_program_code(self):
def init_program_code(self):
def get_code(name):
return get_shader_code_from_file(
os.path.join(self.shader_folder, f"{name}.glsl")
)
return {
self.program_code = {
"vertex_shader": get_code("vert"),
"geometry_shader": get_code("geom"),
"fragment_shader": get_code("frag"),
}
def get_program_code(self):
return self.program_code
def replace_code(self, old, new):
code_map = self.program_code
for (name, code) in code_map.items():
if code_map[name] is None:
continue
code_map[name] = re.sub(old, new, code_map[name])
self.refresh_id()
def combine_with(self, *shader_wrappers):
# Assume they are of the same type
if len(shader_wrappers) == 0:
@ -132,3 +149,26 @@ def get_shader_code_from_file(filename):
)
result = result.replace(line, inserted_code)
return result
def get_colormap_code(colormap="viridis"):
code = """
const vec3[9] COLOR_MAP_DATA = vec3[9](// INSERT DATA //);
vec3 colormap(float value, float min_val, float max_val){
float alpha = smoothstep(min_val, max_val, value);
int disc_alpha = min(int(alpha * 8), 7);
return mix(
COLOR_MAP_DATA[disc_alpha],
COLOR_MAP_DATA[disc_alpha + 1],
8.0 * alpha - disc_alpha
);
}
"""
data = COLORMAPS[colormap]
insertion = "".join(
"vec3({}, {}, {}),".format(*vect)
for vect in data
)
insertion = insertion[:-1] # Remove final comma
return code.replace("// INSERT DATA //", insertion)

View file

@ -1,12 +1,23 @@
vec4 add_light(vec4 raw_color, vec3 point, vec3 unit_normal, vec3 light_coords, float gloss, float shadow){
if(gloss == 0.0 && shadow == 0.0) return raw_color;
///// INSERT COLOR_MAP FUNCTION HERE /////
vec4 add_light(vec4 color,
vec3 point,
vec3 unit_normal,
vec3 light_coords,
float gloss,
float shadow){
///// INSERT COLOR FUNCTION HERE /////
// The line above may be replaced by arbitrary code snippets, as per
// the method Mobject.set_color_by_code
if(gloss == 0.0 && shadow == 0.0) return color;
// TODO, do we actually want this? It effectively treats surfaces as two-sided
if(unit_normal.z < 0){
unit_normal *= -1;
unit_normal *= -1;
}
float camera_distance = 6; // TODO, read this in as a uniform?
// TODO, read this in as a uniform?
float camera_distance = 6;
// Assume everything has already been rotated such that camera is in the z-direction
vec3 to_camera = vec3(0, 0, camera_distance) - point;
vec3 to_light = light_coords - point;
@ -16,7 +27,7 @@ vec4 add_light(vec4 raw_color, vec3 point, vec3 unit_normal, vec3 light_coords,
float dp2 = dot(normalize(to_light), unit_normal);
float darkening = mix(1, max(dp2, 0), shadow);
return vec4(
darkening * mix(raw_color.rgb, vec3(1.0), shine),
raw_color.a
darkening * mix(color.rgb, vec3(1.0), shine),
color.a
);
}