Moved around where batching mobjects by shader type occurs, while also pulling out some of the shader_id helper functions

This commit is contained in:
Grant Sanderson 2020-02-17 12:14:40 -08:00
parent 673b85f129
commit 9d4b16d03f
4 changed files with 179 additions and 80 deletions

View file

@ -13,6 +13,9 @@ from manimlib.mobject.mobject import Mobject
from manimlib.utils.config_ops import digest_config from manimlib.utils.config_ops import digest_config
from manimlib.utils.iterables import batch_by_property from manimlib.utils.iterables import batch_by_property
from manimlib.utils.simple_functions import fdiv from manimlib.utils.simple_functions import fdiv
from manimlib.utils.shaders import shader_info_to_id
from manimlib.utils.shaders import shader_id_to_info
from manimlib.utils.shaders import get_shader_code_from_file
# TODO, think about how to incorporate perspective, # TODO, think about how to incorporate perspective,
@ -209,17 +212,18 @@ class Camera(object):
# Rendering # Rendering
def capture(self, *mobjects, **kwargs): def capture(self, *mobjects, **kwargs):
shader_infos = list(it.chain(*[ shader_infos = it.chain(*[
mob.get_shader_info_list() mob.get_shader_info_list()
for mob in mobjects for mob in mobjects
])) ])
# TODO, batching works well when the mobjects are already organized, # TODO, batching works well when the mobjects are already organized,
# but can we somehow use z-buffering to better effect here? # but can we somehow use z-buffering to better effect here?
batches = batch_by_property(shader_infos, self.get_shader_id) batches = batch_by_property(shader_infos, shader_info_to_id)
for info_group, sid in batches: for info_group, sid in batches:
shader = self.get_shader(sid) shader = self.get_shader(sid)
data = np.hstack([info["data"] for info in info_group]) data = np.hstack([info["data"] for info in info_group])
render_primative = info_group[0]["render_primative"] render_primative = int(info_group[0]["render_primative"])
self.render_from_shader(shader, data, render_primative) self.render_from_shader(shader, data, render_primative)
# Shaders # Shaders
@ -227,57 +231,24 @@ class Camera(object):
# Initialize with the null id going to None # Initialize with the null id going to None
self.id_to_shader = {"": None} self.id_to_shader = {"": None}
def get_shader_id(self, shader_info):
# A unique id for a shader based on the names of the files holding its code
vert, geom, frag, text = [
shader_info.get(key, "") or ""
for key in ["vert", "geom", "frag", "texture_path"]
]
if not vert or not frag:
# Not an actual shader
return ""
return "|".join([vert, geom, frag, text])
def get_shader(self, sid): def get_shader(self, sid):
if sid not in self.id_to_shader: if sid not in self.id_to_shader:
vert, geom, frag, text = sid.split("|") info = shader_id_to_info(sid)
shader = self.ctx.program( shader = self.ctx.program(
vertex_shader=self.get_shader_code_from_file(vert), vertex_shader=get_shader_code_from_file(info["vert"]),
geometry_shader=self.get_shader_code_from_file(geom), geometry_shader=get_shader_code_from_file(info["geom"]),
fragment_shader=self.get_shader_code_from_file(frag), fragment_shader=get_shader_code_from_file(info["frag"]),
) )
if text: if info["texture_path"]:
# TODO, this currently assumes that the uniform Sampler2D # TODO, this currently assumes that the uniform Sampler2D
# is named Texture # is named Texture
tid = self.get_texture_id(text) tid = self.get_texture_id(info["texture_path"])
shader["Texture"].value = tid shader["Texture"].value = tid
self.set_shader_uniforms(shader) self.set_shader_uniforms(shader)
self.id_to_shader[sid] = shader self.id_to_shader[sid] = shader
return self.id_to_shader[sid] return self.id_to_shader[sid]
def get_shader_code_from_file(self, filename):
if len(filename) == 0:
return None
filepath = os.path.join(SHADER_DIR, filename)
if not os.path.exists(filepath):
warnings.warn(f"No file at {file_path}")
return
with open(filepath, "r") as f:
result = f.read()
# To share functionality between shaders, some functions are read in
# from other files an inserted into the relevant strings before
# passing to ctx.program for compiling
# Replace "#INSERT " lines with relevant code
insertions = re.findall(r"^#INSERT .*\.glsl$", result, flags=re.MULTILINE)
for line in insertions:
inserted_code = self.get_shader_code_from_file(line.replace("#INSERT ", ""))
result = result.replace(line, inserted_code)
return result
def set_shader_uniforms(self, shader): def set_shader_uniforms(self, shader):
if shader is None: if shader is None:
return return

View file

@ -23,6 +23,7 @@ from manimlib.utils.space_ops import cross2d
from manimlib.utils.space_ops import get_norm from manimlib.utils.space_ops import get_norm
from manimlib.utils.space_ops import angle_between_vectors from manimlib.utils.space_ops import angle_between_vectors
from manimlib.utils.space_ops import earclip_triangulation from manimlib.utils.space_ops import earclip_triangulation
from manimlib.utils.shaders import get_shader_info
class VMobject(Mobject): class VMobject(Mobject):
@ -368,6 +369,7 @@ class VMobject(Mobject):
# Points # Points
def set_points(self, points): def set_points(self, points):
self.points = np.array(points) self.points = np.array(points)
self.refresh_triangulation()
return self return self
def get_points(self): def get_points(self):
@ -525,6 +527,7 @@ class VMobject(Mobject):
anchors[:-1], anchors[1:], 0.5 anchors[:-1], anchors[1:], 0.5
) )
submob.append_points(new_subpath) submob.append_points(new_subpath)
submob.refresh_triangulation()
return self return self
def make_smooth(self): def make_smooth(self):
@ -700,6 +703,7 @@ class VMobject(Mobject):
if new_path_point: if new_path_point:
self.append_points([new_path_point]) self.append_points([new_path_point])
self.refresh_triangulation()
return self return self
def insert_n_curves_to_point_list(self, n, points): def insert_n_curves_to_point_list(self, n, points):
@ -772,6 +776,8 @@ class VMobject(Mobject):
if alpha == 1.0: if alpha == 1.0:
setattr(self, attr, getattr(mobject2, attr)) setattr(self, attr, getattr(mobject2, attr))
# TODO, somehow do this using stroke_width changes
# so as to not have to change the point list
def pointwise_become_partial(self, vmobject, a, b): def pointwise_become_partial(self, vmobject, a, b):
assert(isinstance(vmobject, VMobject)) assert(isinstance(vmobject, VMobject))
# Partial curve includes three portions: # Partial curve includes three portions:
@ -803,6 +809,7 @@ class VMobject(Mobject):
self.append_points(partial_bezier_points( self.append_points(partial_bezier_points(
bezier_tuple[upper_index], 0, upper_residue bezier_tuple[upper_index], 0, upper_residue
)) ))
self.refresh_triangulation()
return self return self
def get_subcurve(self, a, b): def get_subcurve(self, a, b):
@ -816,27 +823,50 @@ class VMobject(Mobject):
self.stroke_data = np.zeros(len(self.points), dtype=self.stroke_dtype) self.stroke_data = np.zeros(len(self.points), dtype=self.stroke_dtype)
def get_shader_info_list(self): def get_shader_info_list(self):
stroke_info = get_shader_info(
vert_file=self.stroke_vert_shader_file,
geom_file=self.stroke_geom_shader_file,
frag_file=self.stroke_frag_shader_file,
texture_path=self.texture_path,
render_primative=self.render_primative,
)
fill_info = get_shader_info(
vert_file=self.fill_vert_shader_file,
geom_file=self.fill_geom_shader_file,
frag_file=self.fill_frag_shader_file,
texture_path=self.texture_path,
render_primative=self.render_primative,
)
back_stroke_data = []
stroke_data = []
fill_data = []
for submob in self.family_members_with_points():
stroke_width = submob.get_stroke_width()
stroke_opacity = submob.get_stroke_opacity()
fill_opacity = submob.get_fill_opacity()
if fill_opacity > 0:
fill_data.append(submob.get_fill_shader_data())
if stroke_width > 0 and stroke_opacity > 0:
if submob.draw_stroke_behind_fill:
data = back_stroke_data
else:
data = stroke_data
data.append(submob.get_stroke_shader_data())
result = [] result = []
if self.get_fill_opacity() > 0: if back_stroke_data:
result.append({ back_stroke_info = dict(stroke_info) # Copy
"data": self.get_fill_shader_data(), back_stroke_info["data"] = np.hstack(back_stroke_data)
"vert": self.fill_vert_shader_file, result.append(back_stroke_info)
"geom": self.fill_geom_shader_file, if fill_data:
"frag": self.fill_frag_shader_file, fill_info["data"] = np.hstack(fill_data)
"render_primative": self.render_primative, result.append(fill_info)
"texture_path": self.texture_path, if stroke_data:
}) stroke_info["data"] = np.hstack(stroke_data)
if self.get_stroke_width() > 0 and self.get_stroke_opacity() > 0: result.append(stroke_info)
result.append({
"data": self.get_stroke_shader_data(),
"vert": self.stroke_vert_shader_file,
"geom": self.stroke_geom_shader_file,
"frag": self.stroke_frag_shader_file,
"render_primative": self.render_primative,
"texture_path": self.texture_path,
})
if len(result) == 2 and self.draw_stroke_behind_fill:
return [result[1], result[0]]
return result return result
def get_stroke_shader_data(self): def get_stroke_shader_data(self):
@ -878,6 +908,10 @@ class VMobject(Mobject):
for sm in self.family_members_with_points(): for sm in self.family_members_with_points():
sm.triangulation_locked = False sm.triangulation_locked = False
def refresh_triangulation(self):
if self.triangulation_locked:
self.lock_triangulation()
def get_signed_polygonal_area(self): def get_signed_polygonal_area(self):
nppc = self.n_points_per_curve nppc = self.n_points_per_curve
p0 = self.points[0::nppc] p0 = self.points[0::nppc]

View file

@ -120,7 +120,6 @@ class Scene(Container):
def print_end_message(self): def print_end_message(self):
print(f"Played {self.num_plays} animations") print(f"Played {self.num_plays} animations")
# TODO, remove this
def set_variables_as_attrs(self, *objects, **newly_named_objects): def set_variables_as_attrs(self, *objects, **newly_named_objects):
""" """
This method is slightly hacky, making it a little easier This method is slightly hacky, making it a little easier
@ -145,21 +144,21 @@ class Scene(Container):
def update_frame(self, dt=0, ignore_skipping=False): def update_frame(self, dt=0, ignore_skipping=False):
self.increment_time(dt) self.increment_time(dt)
self.update_mobjects(dt) # Should skip? self.update_mobjects(dt)
if self.skip_animations and not ignore_skipping: if self.skip_animations and not ignore_skipping:
return return
if self.window: if self.window:
self.window.clear() self.window.clear()
self.camera.clear() self.camera.clear()
self.camera.capture(*self.get_displayed_mobjects()) self.camera.capture(*self.mobjects)
if self.window: if self.window:
self.window.swap_buffers() self.window.swap_buffers()
win_time, win_dt = self.window.timer.next_frame() win_time, win_dt = self.window.timer.next_frame()
while (self.time - self.skip_time - win_time) > 0: while (self.time - self.skip_time - win_time) > 0:
self.window.clear() self.window.clear()
self.camera.capture(*self.get_displayed_mobjects()) self.camera.capture(*self.mobjects)
self.window.swap_buffers() self.window.swap_buffers()
win_time, win_dt = self.window.timer.next_frame() win_time, win_dt = self.window.timer.next_frame()
@ -188,12 +187,6 @@ class Scene(Container):
self.time += dt self.time += dt
### ###
def get_displayed_mobjects(self):
return it.chain(*[
mob.family_members_with_points()
for mob in self.mobjects
])
def get_top_level_mobjects(self): def get_top_level_mobjects(self):
# Return only those which are not in the family # Return only those which are not in the family
# of another mobject from the scene # of another mobject from the scene
@ -278,7 +271,8 @@ class Scene(Container):
step = 1 / self.camera.frame_rate step = 1 / self.camera.frame_rate
times = np.arange(0, run_time, step) times = np.arange(0, run_time, step)
time_progression = ProgressDisplay( time_progression = ProgressDisplay(
times, total=n_iterations, times,
total=n_iterations,
leave=self.leave_progress_bars, leave=self.leave_progress_bars,
ascii=False if platform.system() != 'Windows' else True ascii=False if platform.system() != 'Windows' else True
) )
@ -291,9 +285,8 @@ class Scene(Container):
run_time = self.get_run_time(animations) run_time = self.get_run_time(animations)
time_progression = self.get_time_progression(run_time) time_progression = self.get_time_progression(run_time)
time_progression.set_description("".join([ time_progression.set_description("".join([
"Animation {}: ".format(self.num_plays), f"Animation {self.num_plays}: {animations[0]}",
str(animations[0]), ", etc." if len(animations) > 1 else "",
(", etc." if len(animations) > 1 else ""),
])) ]))
return time_progression return time_progression
@ -403,9 +396,6 @@ class Scene(Container):
curr_mobjects += mob.get_family() curr_mobjects += mob.get_family()
def progress_through_animations(self, animations): def progress_through_animations(self, animations):
# Paint all non-moving objects onto the screen, so they don't
# have to be rendered every frame
# moving_mobjects = self.get_moving_mobjects(*animations)
last_t = 0 last_t = 0
for t in self.get_animation_time_progression(animations): for t in self.get_animation_time_progression(animations):
dt = t - last_t dt = t - last_t

104
manimlib/utils/shaders.py Normal file
View file

@ -0,0 +1,104 @@
import os
import warnings
import re
import moderngl
from manimlib.constants import SHADER_DIR
# Mobjects that should be rendered with
# the same shader will be organized and
# clumped together based on keeping track
# of a dict holding all the relevant information
# to that shader
SHADER_INFO_KEYS = [
"data",
"vert",
"geom",
"frag",
"texture_path",
"render_primative",
]
def get_shader_info(data=None,
vert_file=None,
geom_file=None,
frag_file=None,
texture_path=None,
render_primative=moderngl.TRIANGLE_STRIP):
return {
key: value
for key, value in zip(
SHADER_INFO_KEYS,
[
data, vert_file, geom_file, frag_file,
texture_path, str(render_primative)
]
)
}
def is_valid_shader_info(shader_info):
data = shader_info["data"]
return all([
data is not None and len(data) > 0,
shader_info["vert"],
shader_info["frag"],
])
def shader_info_to_id(shader_info):
# A unique id for a shader based on the
# files holding its code and texture
return "|".join([
shader_info.get(key, "") or ""
for key in SHADER_INFO_KEYS[1:]
])
def shader_id_to_info(sid):
return {
key: (value or None)
for key, value in zip(
SHADER_INFO_KEYS,
[None, *sid.split("|")]
)
}
def same_shader_type(info1, info2):
return all([
info1[key] == info2[key]
for key in [
"vert",
"geom",
"frag",
"texture_path",
"render_primative",
]
])
def get_shader_code_from_file(filename):
if not filename:
return None
filepath = os.path.join(SHADER_DIR, filename)
if not os.path.exists(filepath):
warnings.warn(f"No file at {filepath}")
return
with open(filepath, "r") as f:
result = f.read()
# To share functionality between shaders, some functions are read in
# from other files an inserted into the relevant strings before
# passing to ctx.program for compiling
# Replace "#INSERT " lines with relevant code
insertions = re.findall(r"^#INSERT .*\.glsl$", result, flags=re.MULTILINE)
for line in insertions:
inserted_code = get_shader_code_from_file(line.replace("#INSERT ", ""))
result = result.replace(line, inserted_code)
return result