diff --git a/manimlib/mobject/mobject.py b/manimlib/mobject/mobject.py index b3bac0ac..276efd7d 100644 --- a/manimlib/mobject/mobject.py +++ b/manimlib/mobject/mobject.py @@ -11,6 +11,7 @@ import numpy as np from manimlib.constants import * from manimlib.utils.color import color_gradient from manimlib.utils.color import interpolate_color +from manimlib.utils.color import get_colormap_list from manimlib.utils.config_ops import digest_config from manimlib.utils.iterables import batch_by_property from manimlib.utils.iterables import list_update @@ -1275,12 +1276,17 @@ class Mobject(object): Pass in a glsl expression in terms of x, y and z which returns a float. """ + # TODO, add a version of this which changes the point data instead + # of the shader code for char in "xyz": glsl_snippet = glsl_snippet.replace(char, "point." + char) + rgb_list = get_colormap_list(colormap) self.set_color_by_code( "color.rgb = float_to_color({}, {}, {}, {});".format( - glsl_snippet, float(min_value), float(max_value), - get_colormap_code(colormap) + glsl_snippet, + float(min_value), + float(max_value), + get_colormap_code(rgb_list) ) ) return self diff --git a/manimlib/shader_wrapper.py b/manimlib/shader_wrapper.py index 375384a1..282deb8d 100644 --- a/manimlib/shader_wrapper.py +++ b/manimlib/shader_wrapper.py @@ -3,7 +3,6 @@ import re import moderngl import numpy as np import copy -from matplotlib.cm import get_cmap from manimlib.utils.directories import get_shader_dir from manimlib.utils.file_ops import find_file @@ -153,14 +152,9 @@ def get_shader_code_from_file(filename): return result -def get_colormap_code(colormap="viridis"): - rgbs = get_cmap(colormap).colors # Make more general? - sparse_rgbs = [ - rgbs[int(n)] - for n in np.linspace(0, len(rgbs) - 1, 9) - ] +def get_colormap_code(rgb_list): data = ",".join( - "vec3({}, {}, {})".format(*color) - for color in sparse_rgbs + "vec3({}, {}, {})".format(*rgb) + for rgb in rgb_list ) - return f"vec3[9]({data})" + return f"vec3[{len(rgb_list)}]({data})" diff --git a/manimlib/utils/color.py b/manimlib/utils/color.py index a362b128..efe5e5b7 100644 --- a/manimlib/utils/color.py +++ b/manimlib/utils/color.py @@ -2,6 +2,7 @@ import random from colour import Color import numpy as np +from matplotlib.cm import get_cmap from manimlib.constants import PALETTE from manimlib.constants import WHITE @@ -112,3 +113,11 @@ def get_shaded_rgb(rgb, point, unit_normal_vect, light_source): result = rgb + factor clip_in_place(rgb + factor, 0, 1) return result + + +def get_colormap_list(map_name="viridis", n_colors=9): + rgbs = get_cmap(map_name).colors # Make more general? + return [ + rgbs[int(n)] + for n in np.linspace(0, len(rgbs) - 1, n_colors) + ]