mirror of
https://github.com/3b1b/manim.git
synced 2025-04-13 09:47:07 +00:00
Moved around where batching mobjects by shader type occurs, while also pulling out some of the shader_id helper functions
This commit is contained in:
parent
673b85f129
commit
9d4b16d03f
4 changed files with 179 additions and 80 deletions
|
@ -13,6 +13,9 @@ from manimlib.mobject.mobject import Mobject
|
||||||
from manimlib.utils.config_ops import digest_config
|
from manimlib.utils.config_ops import digest_config
|
||||||
from manimlib.utils.iterables import batch_by_property
|
from manimlib.utils.iterables import batch_by_property
|
||||||
from manimlib.utils.simple_functions import fdiv
|
from manimlib.utils.simple_functions import fdiv
|
||||||
|
from manimlib.utils.shaders import shader_info_to_id
|
||||||
|
from manimlib.utils.shaders import shader_id_to_info
|
||||||
|
from manimlib.utils.shaders import get_shader_code_from_file
|
||||||
|
|
||||||
|
|
||||||
# TODO, think about how to incorporate perspective,
|
# TODO, think about how to incorporate perspective,
|
||||||
|
@ -209,17 +212,18 @@ class Camera(object):
|
||||||
|
|
||||||
# Rendering
|
# Rendering
|
||||||
def capture(self, *mobjects, **kwargs):
|
def capture(self, *mobjects, **kwargs):
|
||||||
shader_infos = list(it.chain(*[
|
shader_infos = it.chain(*[
|
||||||
mob.get_shader_info_list()
|
mob.get_shader_info_list()
|
||||||
for mob in mobjects
|
for mob in mobjects
|
||||||
]))
|
])
|
||||||
# TODO, batching works well when the mobjects are already organized,
|
# TODO, batching works well when the mobjects are already organized,
|
||||||
# but can we somehow use z-buffering to better effect here?
|
# but can we somehow use z-buffering to better effect here?
|
||||||
batches = batch_by_property(shader_infos, self.get_shader_id)
|
batches = batch_by_property(shader_infos, shader_info_to_id)
|
||||||
|
|
||||||
for info_group, sid in batches:
|
for info_group, sid in batches:
|
||||||
shader = self.get_shader(sid)
|
shader = self.get_shader(sid)
|
||||||
data = np.hstack([info["data"] for info in info_group])
|
data = np.hstack([info["data"] for info in info_group])
|
||||||
render_primative = info_group[0]["render_primative"]
|
render_primative = int(info_group[0]["render_primative"])
|
||||||
self.render_from_shader(shader, data, render_primative)
|
self.render_from_shader(shader, data, render_primative)
|
||||||
|
|
||||||
# Shaders
|
# Shaders
|
||||||
|
@ -227,57 +231,24 @@ class Camera(object):
|
||||||
# Initialize with the null id going to None
|
# Initialize with the null id going to None
|
||||||
self.id_to_shader = {"": None}
|
self.id_to_shader = {"": None}
|
||||||
|
|
||||||
def get_shader_id(self, shader_info):
|
|
||||||
# A unique id for a shader based on the names of the files holding its code
|
|
||||||
vert, geom, frag, text = [
|
|
||||||
shader_info.get(key, "") or ""
|
|
||||||
for key in ["vert", "geom", "frag", "texture_path"]
|
|
||||||
]
|
|
||||||
if not vert or not frag:
|
|
||||||
# Not an actual shader
|
|
||||||
return ""
|
|
||||||
return "|".join([vert, geom, frag, text])
|
|
||||||
|
|
||||||
def get_shader(self, sid):
|
def get_shader(self, sid):
|
||||||
if sid not in self.id_to_shader:
|
if sid not in self.id_to_shader:
|
||||||
vert, geom, frag, text = sid.split("|")
|
info = shader_id_to_info(sid)
|
||||||
shader = self.ctx.program(
|
shader = self.ctx.program(
|
||||||
vertex_shader=self.get_shader_code_from_file(vert),
|
vertex_shader=get_shader_code_from_file(info["vert"]),
|
||||||
geometry_shader=self.get_shader_code_from_file(geom),
|
geometry_shader=get_shader_code_from_file(info["geom"]),
|
||||||
fragment_shader=self.get_shader_code_from_file(frag),
|
fragment_shader=get_shader_code_from_file(info["frag"]),
|
||||||
)
|
)
|
||||||
if text:
|
if info["texture_path"]:
|
||||||
# TODO, this currently assumes that the uniform Sampler2D
|
# TODO, this currently assumes that the uniform Sampler2D
|
||||||
# is named Texture
|
# is named Texture
|
||||||
tid = self.get_texture_id(text)
|
tid = self.get_texture_id(info["texture_path"])
|
||||||
shader["Texture"].value = tid
|
shader["Texture"].value = tid
|
||||||
|
|
||||||
self.set_shader_uniforms(shader)
|
self.set_shader_uniforms(shader)
|
||||||
self.id_to_shader[sid] = shader
|
self.id_to_shader[sid] = shader
|
||||||
return self.id_to_shader[sid]
|
return self.id_to_shader[sid]
|
||||||
|
|
||||||
def get_shader_code_from_file(self, filename):
|
|
||||||
if len(filename) == 0:
|
|
||||||
return None
|
|
||||||
|
|
||||||
filepath = os.path.join(SHADER_DIR, filename)
|
|
||||||
if not os.path.exists(filepath):
|
|
||||||
warnings.warn(f"No file at {file_path}")
|
|
||||||
return
|
|
||||||
|
|
||||||
with open(filepath, "r") as f:
|
|
||||||
result = f.read()
|
|
||||||
|
|
||||||
# To share functionality between shaders, some functions are read in
|
|
||||||
# from other files an inserted into the relevant strings before
|
|
||||||
# passing to ctx.program for compiling
|
|
||||||
# Replace "#INSERT " lines with relevant code
|
|
||||||
insertions = re.findall(r"^#INSERT .*\.glsl$", result, flags=re.MULTILINE)
|
|
||||||
for line in insertions:
|
|
||||||
inserted_code = self.get_shader_code_from_file(line.replace("#INSERT ", ""))
|
|
||||||
result = result.replace(line, inserted_code)
|
|
||||||
return result
|
|
||||||
|
|
||||||
def set_shader_uniforms(self, shader):
|
def set_shader_uniforms(self, shader):
|
||||||
if shader is None:
|
if shader is None:
|
||||||
return
|
return
|
||||||
|
|
|
@ -23,6 +23,7 @@ from manimlib.utils.space_ops import cross2d
|
||||||
from manimlib.utils.space_ops import get_norm
|
from manimlib.utils.space_ops import get_norm
|
||||||
from manimlib.utils.space_ops import angle_between_vectors
|
from manimlib.utils.space_ops import angle_between_vectors
|
||||||
from manimlib.utils.space_ops import earclip_triangulation
|
from manimlib.utils.space_ops import earclip_triangulation
|
||||||
|
from manimlib.utils.shaders import get_shader_info
|
||||||
|
|
||||||
|
|
||||||
class VMobject(Mobject):
|
class VMobject(Mobject):
|
||||||
|
@ -368,6 +369,7 @@ class VMobject(Mobject):
|
||||||
# Points
|
# Points
|
||||||
def set_points(self, points):
|
def set_points(self, points):
|
||||||
self.points = np.array(points)
|
self.points = np.array(points)
|
||||||
|
self.refresh_triangulation()
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def get_points(self):
|
def get_points(self):
|
||||||
|
@ -525,6 +527,7 @@ class VMobject(Mobject):
|
||||||
anchors[:-1], anchors[1:], 0.5
|
anchors[:-1], anchors[1:], 0.5
|
||||||
)
|
)
|
||||||
submob.append_points(new_subpath)
|
submob.append_points(new_subpath)
|
||||||
|
submob.refresh_triangulation()
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def make_smooth(self):
|
def make_smooth(self):
|
||||||
|
@ -700,6 +703,7 @@ class VMobject(Mobject):
|
||||||
|
|
||||||
if new_path_point:
|
if new_path_point:
|
||||||
self.append_points([new_path_point])
|
self.append_points([new_path_point])
|
||||||
|
self.refresh_triangulation()
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def insert_n_curves_to_point_list(self, n, points):
|
def insert_n_curves_to_point_list(self, n, points):
|
||||||
|
@ -772,6 +776,8 @@ class VMobject(Mobject):
|
||||||
if alpha == 1.0:
|
if alpha == 1.0:
|
||||||
setattr(self, attr, getattr(mobject2, attr))
|
setattr(self, attr, getattr(mobject2, attr))
|
||||||
|
|
||||||
|
# TODO, somehow do this using stroke_width changes
|
||||||
|
# so as to not have to change the point list
|
||||||
def pointwise_become_partial(self, vmobject, a, b):
|
def pointwise_become_partial(self, vmobject, a, b):
|
||||||
assert(isinstance(vmobject, VMobject))
|
assert(isinstance(vmobject, VMobject))
|
||||||
# Partial curve includes three portions:
|
# Partial curve includes three portions:
|
||||||
|
@ -803,6 +809,7 @@ class VMobject(Mobject):
|
||||||
self.append_points(partial_bezier_points(
|
self.append_points(partial_bezier_points(
|
||||||
bezier_tuple[upper_index], 0, upper_residue
|
bezier_tuple[upper_index], 0, upper_residue
|
||||||
))
|
))
|
||||||
|
self.refresh_triangulation()
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def get_subcurve(self, a, b):
|
def get_subcurve(self, a, b):
|
||||||
|
@ -816,27 +823,50 @@ class VMobject(Mobject):
|
||||||
self.stroke_data = np.zeros(len(self.points), dtype=self.stroke_dtype)
|
self.stroke_data = np.zeros(len(self.points), dtype=self.stroke_dtype)
|
||||||
|
|
||||||
def get_shader_info_list(self):
|
def get_shader_info_list(self):
|
||||||
|
stroke_info = get_shader_info(
|
||||||
|
vert_file=self.stroke_vert_shader_file,
|
||||||
|
geom_file=self.stroke_geom_shader_file,
|
||||||
|
frag_file=self.stroke_frag_shader_file,
|
||||||
|
texture_path=self.texture_path,
|
||||||
|
render_primative=self.render_primative,
|
||||||
|
)
|
||||||
|
fill_info = get_shader_info(
|
||||||
|
vert_file=self.fill_vert_shader_file,
|
||||||
|
geom_file=self.fill_geom_shader_file,
|
||||||
|
frag_file=self.fill_frag_shader_file,
|
||||||
|
texture_path=self.texture_path,
|
||||||
|
render_primative=self.render_primative,
|
||||||
|
)
|
||||||
|
|
||||||
|
back_stroke_data = []
|
||||||
|
stroke_data = []
|
||||||
|
fill_data = []
|
||||||
|
for submob in self.family_members_with_points():
|
||||||
|
stroke_width = submob.get_stroke_width()
|
||||||
|
stroke_opacity = submob.get_stroke_opacity()
|
||||||
|
fill_opacity = submob.get_fill_opacity()
|
||||||
|
|
||||||
|
if fill_opacity > 0:
|
||||||
|
fill_data.append(submob.get_fill_shader_data())
|
||||||
|
|
||||||
|
if stroke_width > 0 and stroke_opacity > 0:
|
||||||
|
if submob.draw_stroke_behind_fill:
|
||||||
|
data = back_stroke_data
|
||||||
|
else:
|
||||||
|
data = stroke_data
|
||||||
|
data.append(submob.get_stroke_shader_data())
|
||||||
|
|
||||||
result = []
|
result = []
|
||||||
if self.get_fill_opacity() > 0:
|
if back_stroke_data:
|
||||||
result.append({
|
back_stroke_info = dict(stroke_info) # Copy
|
||||||
"data": self.get_fill_shader_data(),
|
back_stroke_info["data"] = np.hstack(back_stroke_data)
|
||||||
"vert": self.fill_vert_shader_file,
|
result.append(back_stroke_info)
|
||||||
"geom": self.fill_geom_shader_file,
|
if fill_data:
|
||||||
"frag": self.fill_frag_shader_file,
|
fill_info["data"] = np.hstack(fill_data)
|
||||||
"render_primative": self.render_primative,
|
result.append(fill_info)
|
||||||
"texture_path": self.texture_path,
|
if stroke_data:
|
||||||
})
|
stroke_info["data"] = np.hstack(stroke_data)
|
||||||
if self.get_stroke_width() > 0 and self.get_stroke_opacity() > 0:
|
result.append(stroke_info)
|
||||||
result.append({
|
|
||||||
"data": self.get_stroke_shader_data(),
|
|
||||||
"vert": self.stroke_vert_shader_file,
|
|
||||||
"geom": self.stroke_geom_shader_file,
|
|
||||||
"frag": self.stroke_frag_shader_file,
|
|
||||||
"render_primative": self.render_primative,
|
|
||||||
"texture_path": self.texture_path,
|
|
||||||
})
|
|
||||||
if len(result) == 2 and self.draw_stroke_behind_fill:
|
|
||||||
return [result[1], result[0]]
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def get_stroke_shader_data(self):
|
def get_stroke_shader_data(self):
|
||||||
|
@ -878,6 +908,10 @@ class VMobject(Mobject):
|
||||||
for sm in self.family_members_with_points():
|
for sm in self.family_members_with_points():
|
||||||
sm.triangulation_locked = False
|
sm.triangulation_locked = False
|
||||||
|
|
||||||
|
def refresh_triangulation(self):
|
||||||
|
if self.triangulation_locked:
|
||||||
|
self.lock_triangulation()
|
||||||
|
|
||||||
def get_signed_polygonal_area(self):
|
def get_signed_polygonal_area(self):
|
||||||
nppc = self.n_points_per_curve
|
nppc = self.n_points_per_curve
|
||||||
p0 = self.points[0::nppc]
|
p0 = self.points[0::nppc]
|
||||||
|
|
|
@ -120,7 +120,6 @@ class Scene(Container):
|
||||||
def print_end_message(self):
|
def print_end_message(self):
|
||||||
print(f"Played {self.num_plays} animations")
|
print(f"Played {self.num_plays} animations")
|
||||||
|
|
||||||
# TODO, remove this
|
|
||||||
def set_variables_as_attrs(self, *objects, **newly_named_objects):
|
def set_variables_as_attrs(self, *objects, **newly_named_objects):
|
||||||
"""
|
"""
|
||||||
This method is slightly hacky, making it a little easier
|
This method is slightly hacky, making it a little easier
|
||||||
|
@ -145,21 +144,21 @@ class Scene(Container):
|
||||||
|
|
||||||
def update_frame(self, dt=0, ignore_skipping=False):
|
def update_frame(self, dt=0, ignore_skipping=False):
|
||||||
self.increment_time(dt)
|
self.increment_time(dt)
|
||||||
self.update_mobjects(dt) # Should skip?
|
self.update_mobjects(dt)
|
||||||
if self.skip_animations and not ignore_skipping:
|
if self.skip_animations and not ignore_skipping:
|
||||||
return
|
return
|
||||||
|
|
||||||
if self.window:
|
if self.window:
|
||||||
self.window.clear()
|
self.window.clear()
|
||||||
self.camera.clear()
|
self.camera.clear()
|
||||||
self.camera.capture(*self.get_displayed_mobjects())
|
self.camera.capture(*self.mobjects)
|
||||||
|
|
||||||
if self.window:
|
if self.window:
|
||||||
self.window.swap_buffers()
|
self.window.swap_buffers()
|
||||||
win_time, win_dt = self.window.timer.next_frame()
|
win_time, win_dt = self.window.timer.next_frame()
|
||||||
while (self.time - self.skip_time - win_time) > 0:
|
while (self.time - self.skip_time - win_time) > 0:
|
||||||
self.window.clear()
|
self.window.clear()
|
||||||
self.camera.capture(*self.get_displayed_mobjects())
|
self.camera.capture(*self.mobjects)
|
||||||
self.window.swap_buffers()
|
self.window.swap_buffers()
|
||||||
win_time, win_dt = self.window.timer.next_frame()
|
win_time, win_dt = self.window.timer.next_frame()
|
||||||
|
|
||||||
|
@ -188,12 +187,6 @@ class Scene(Container):
|
||||||
self.time += dt
|
self.time += dt
|
||||||
|
|
||||||
###
|
###
|
||||||
def get_displayed_mobjects(self):
|
|
||||||
return it.chain(*[
|
|
||||||
mob.family_members_with_points()
|
|
||||||
for mob in self.mobjects
|
|
||||||
])
|
|
||||||
|
|
||||||
def get_top_level_mobjects(self):
|
def get_top_level_mobjects(self):
|
||||||
# Return only those which are not in the family
|
# Return only those which are not in the family
|
||||||
# of another mobject from the scene
|
# of another mobject from the scene
|
||||||
|
@ -278,7 +271,8 @@ class Scene(Container):
|
||||||
step = 1 / self.camera.frame_rate
|
step = 1 / self.camera.frame_rate
|
||||||
times = np.arange(0, run_time, step)
|
times = np.arange(0, run_time, step)
|
||||||
time_progression = ProgressDisplay(
|
time_progression = ProgressDisplay(
|
||||||
times, total=n_iterations,
|
times,
|
||||||
|
total=n_iterations,
|
||||||
leave=self.leave_progress_bars,
|
leave=self.leave_progress_bars,
|
||||||
ascii=False if platform.system() != 'Windows' else True
|
ascii=False if platform.system() != 'Windows' else True
|
||||||
)
|
)
|
||||||
|
@ -291,9 +285,8 @@ class Scene(Container):
|
||||||
run_time = self.get_run_time(animations)
|
run_time = self.get_run_time(animations)
|
||||||
time_progression = self.get_time_progression(run_time)
|
time_progression = self.get_time_progression(run_time)
|
||||||
time_progression.set_description("".join([
|
time_progression.set_description("".join([
|
||||||
"Animation {}: ".format(self.num_plays),
|
f"Animation {self.num_plays}: {animations[0]}",
|
||||||
str(animations[0]),
|
", etc." if len(animations) > 1 else "",
|
||||||
(", etc." if len(animations) > 1 else ""),
|
|
||||||
]))
|
]))
|
||||||
return time_progression
|
return time_progression
|
||||||
|
|
||||||
|
@ -403,9 +396,6 @@ class Scene(Container):
|
||||||
curr_mobjects += mob.get_family()
|
curr_mobjects += mob.get_family()
|
||||||
|
|
||||||
def progress_through_animations(self, animations):
|
def progress_through_animations(self, animations):
|
||||||
# Paint all non-moving objects onto the screen, so they don't
|
|
||||||
# have to be rendered every frame
|
|
||||||
# moving_mobjects = self.get_moving_mobjects(*animations)
|
|
||||||
last_t = 0
|
last_t = 0
|
||||||
for t in self.get_animation_time_progression(animations):
|
for t in self.get_animation_time_progression(animations):
|
||||||
dt = t - last_t
|
dt = t - last_t
|
||||||
|
|
104
manimlib/utils/shaders.py
Normal file
104
manimlib/utils/shaders.py
Normal file
|
@ -0,0 +1,104 @@
|
||||||
|
import os
|
||||||
|
import warnings
|
||||||
|
import re
|
||||||
|
import moderngl
|
||||||
|
|
||||||
|
from manimlib.constants import SHADER_DIR
|
||||||
|
|
||||||
|
# Mobjects that should be rendered with
|
||||||
|
# the same shader will be organized and
|
||||||
|
# clumped together based on keeping track
|
||||||
|
# of a dict holding all the relevant information
|
||||||
|
# to that shader
|
||||||
|
|
||||||
|
|
||||||
|
SHADER_INFO_KEYS = [
|
||||||
|
"data",
|
||||||
|
"vert",
|
||||||
|
"geom",
|
||||||
|
"frag",
|
||||||
|
"texture_path",
|
||||||
|
"render_primative",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def get_shader_info(data=None,
|
||||||
|
vert_file=None,
|
||||||
|
geom_file=None,
|
||||||
|
frag_file=None,
|
||||||
|
texture_path=None,
|
||||||
|
render_primative=moderngl.TRIANGLE_STRIP):
|
||||||
|
return {
|
||||||
|
key: value
|
||||||
|
for key, value in zip(
|
||||||
|
SHADER_INFO_KEYS,
|
||||||
|
[
|
||||||
|
data, vert_file, geom_file, frag_file,
|
||||||
|
texture_path, str(render_primative)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def is_valid_shader_info(shader_info):
|
||||||
|
data = shader_info["data"]
|
||||||
|
return all([
|
||||||
|
data is not None and len(data) > 0,
|
||||||
|
shader_info["vert"],
|
||||||
|
shader_info["frag"],
|
||||||
|
])
|
||||||
|
|
||||||
|
|
||||||
|
def shader_info_to_id(shader_info):
|
||||||
|
# A unique id for a shader based on the
|
||||||
|
# files holding its code and texture
|
||||||
|
return "|".join([
|
||||||
|
shader_info.get(key, "") or ""
|
||||||
|
for key in SHADER_INFO_KEYS[1:]
|
||||||
|
])
|
||||||
|
|
||||||
|
|
||||||
|
def shader_id_to_info(sid):
|
||||||
|
return {
|
||||||
|
key: (value or None)
|
||||||
|
for key, value in zip(
|
||||||
|
SHADER_INFO_KEYS,
|
||||||
|
[None, *sid.split("|")]
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def same_shader_type(info1, info2):
|
||||||
|
return all([
|
||||||
|
info1[key] == info2[key]
|
||||||
|
for key in [
|
||||||
|
"vert",
|
||||||
|
"geom",
|
||||||
|
"frag",
|
||||||
|
"texture_path",
|
||||||
|
"render_primative",
|
||||||
|
]
|
||||||
|
])
|
||||||
|
|
||||||
|
|
||||||
|
def get_shader_code_from_file(filename):
|
||||||
|
if not filename:
|
||||||
|
return None
|
||||||
|
|
||||||
|
filepath = os.path.join(SHADER_DIR, filename)
|
||||||
|
if not os.path.exists(filepath):
|
||||||
|
warnings.warn(f"No file at {filepath}")
|
||||||
|
return
|
||||||
|
|
||||||
|
with open(filepath, "r") as f:
|
||||||
|
result = f.read()
|
||||||
|
|
||||||
|
# To share functionality between shaders, some functions are read in
|
||||||
|
# from other files an inserted into the relevant strings before
|
||||||
|
# passing to ctx.program for compiling
|
||||||
|
# Replace "#INSERT " lines with relevant code
|
||||||
|
insertions = re.findall(r"^#INSERT .*\.glsl$", result, flags=re.MULTILINE)
|
||||||
|
for line in insertions:
|
||||||
|
inserted_code = get_shader_code_from_file(line.replace("#INSERT ", ""))
|
||||||
|
result = result.replace(line, inserted_code)
|
||||||
|
return result
|
Loading…
Add table
Reference in a new issue