mirror of
https://github.com/3b1b/manim.git
synced 2025-04-13 09:47:07 +00:00
Moved around where batching mobjects by shader type occurs, while also pulling out some of the shader_id helper functions
This commit is contained in:
parent
673b85f129
commit
9d4b16d03f
4 changed files with 179 additions and 80 deletions
|
@ -13,6 +13,9 @@ from manimlib.mobject.mobject import Mobject
|
|||
from manimlib.utils.config_ops import digest_config
|
||||
from manimlib.utils.iterables import batch_by_property
|
||||
from manimlib.utils.simple_functions import fdiv
|
||||
from manimlib.utils.shaders import shader_info_to_id
|
||||
from manimlib.utils.shaders import shader_id_to_info
|
||||
from manimlib.utils.shaders import get_shader_code_from_file
|
||||
|
||||
|
||||
# TODO, think about how to incorporate perspective,
|
||||
|
@ -209,17 +212,18 @@ class Camera(object):
|
|||
|
||||
# Rendering
|
||||
def capture(self, *mobjects, **kwargs):
|
||||
shader_infos = list(it.chain(*[
|
||||
shader_infos = it.chain(*[
|
||||
mob.get_shader_info_list()
|
||||
for mob in mobjects
|
||||
]))
|
||||
])
|
||||
# TODO, batching works well when the mobjects are already organized,
|
||||
# but can we somehow use z-buffering to better effect here?
|
||||
batches = batch_by_property(shader_infos, self.get_shader_id)
|
||||
batches = batch_by_property(shader_infos, shader_info_to_id)
|
||||
|
||||
for info_group, sid in batches:
|
||||
shader = self.get_shader(sid)
|
||||
data = np.hstack([info["data"] for info in info_group])
|
||||
render_primative = info_group[0]["render_primative"]
|
||||
render_primative = int(info_group[0]["render_primative"])
|
||||
self.render_from_shader(shader, data, render_primative)
|
||||
|
||||
# Shaders
|
||||
|
@ -227,57 +231,24 @@ class Camera(object):
|
|||
# Initialize with the null id going to None
|
||||
self.id_to_shader = {"": None}
|
||||
|
||||
def get_shader_id(self, shader_info):
|
||||
# A unique id for a shader based on the names of the files holding its code
|
||||
vert, geom, frag, text = [
|
||||
shader_info.get(key, "") or ""
|
||||
for key in ["vert", "geom", "frag", "texture_path"]
|
||||
]
|
||||
if not vert or not frag:
|
||||
# Not an actual shader
|
||||
return ""
|
||||
return "|".join([vert, geom, frag, text])
|
||||
|
||||
def get_shader(self, sid):
|
||||
if sid not in self.id_to_shader:
|
||||
vert, geom, frag, text = sid.split("|")
|
||||
info = shader_id_to_info(sid)
|
||||
shader = self.ctx.program(
|
||||
vertex_shader=self.get_shader_code_from_file(vert),
|
||||
geometry_shader=self.get_shader_code_from_file(geom),
|
||||
fragment_shader=self.get_shader_code_from_file(frag),
|
||||
vertex_shader=get_shader_code_from_file(info["vert"]),
|
||||
geometry_shader=get_shader_code_from_file(info["geom"]),
|
||||
fragment_shader=get_shader_code_from_file(info["frag"]),
|
||||
)
|
||||
if text:
|
||||
if info["texture_path"]:
|
||||
# TODO, this currently assumes that the uniform Sampler2D
|
||||
# is named Texture
|
||||
tid = self.get_texture_id(text)
|
||||
tid = self.get_texture_id(info["texture_path"])
|
||||
shader["Texture"].value = tid
|
||||
|
||||
self.set_shader_uniforms(shader)
|
||||
self.id_to_shader[sid] = shader
|
||||
return self.id_to_shader[sid]
|
||||
|
||||
def get_shader_code_from_file(self, filename):
|
||||
if len(filename) == 0:
|
||||
return None
|
||||
|
||||
filepath = os.path.join(SHADER_DIR, filename)
|
||||
if not os.path.exists(filepath):
|
||||
warnings.warn(f"No file at {file_path}")
|
||||
return
|
||||
|
||||
with open(filepath, "r") as f:
|
||||
result = f.read()
|
||||
|
||||
# To share functionality between shaders, some functions are read in
|
||||
# from other files an inserted into the relevant strings before
|
||||
# passing to ctx.program for compiling
|
||||
# Replace "#INSERT " lines with relevant code
|
||||
insertions = re.findall(r"^#INSERT .*\.glsl$", result, flags=re.MULTILINE)
|
||||
for line in insertions:
|
||||
inserted_code = self.get_shader_code_from_file(line.replace("#INSERT ", ""))
|
||||
result = result.replace(line, inserted_code)
|
||||
return result
|
||||
|
||||
def set_shader_uniforms(self, shader):
|
||||
if shader is None:
|
||||
return
|
||||
|
|
|
@ -23,6 +23,7 @@ from manimlib.utils.space_ops import cross2d
|
|||
from manimlib.utils.space_ops import get_norm
|
||||
from manimlib.utils.space_ops import angle_between_vectors
|
||||
from manimlib.utils.space_ops import earclip_triangulation
|
||||
from manimlib.utils.shaders import get_shader_info
|
||||
|
||||
|
||||
class VMobject(Mobject):
|
||||
|
@ -368,6 +369,7 @@ class VMobject(Mobject):
|
|||
# Points
|
||||
def set_points(self, points):
|
||||
self.points = np.array(points)
|
||||
self.refresh_triangulation()
|
||||
return self
|
||||
|
||||
def get_points(self):
|
||||
|
@ -525,6 +527,7 @@ class VMobject(Mobject):
|
|||
anchors[:-1], anchors[1:], 0.5
|
||||
)
|
||||
submob.append_points(new_subpath)
|
||||
submob.refresh_triangulation()
|
||||
return self
|
||||
|
||||
def make_smooth(self):
|
||||
|
@ -700,6 +703,7 @@ class VMobject(Mobject):
|
|||
|
||||
if new_path_point:
|
||||
self.append_points([new_path_point])
|
||||
self.refresh_triangulation()
|
||||
return self
|
||||
|
||||
def insert_n_curves_to_point_list(self, n, points):
|
||||
|
@ -772,6 +776,8 @@ class VMobject(Mobject):
|
|||
if alpha == 1.0:
|
||||
setattr(self, attr, getattr(mobject2, attr))
|
||||
|
||||
# TODO, somehow do this using stroke_width changes
|
||||
# so as to not have to change the point list
|
||||
def pointwise_become_partial(self, vmobject, a, b):
|
||||
assert(isinstance(vmobject, VMobject))
|
||||
# Partial curve includes three portions:
|
||||
|
@ -803,6 +809,7 @@ class VMobject(Mobject):
|
|||
self.append_points(partial_bezier_points(
|
||||
bezier_tuple[upper_index], 0, upper_residue
|
||||
))
|
||||
self.refresh_triangulation()
|
||||
return self
|
||||
|
||||
def get_subcurve(self, a, b):
|
||||
|
@ -816,27 +823,50 @@ class VMobject(Mobject):
|
|||
self.stroke_data = np.zeros(len(self.points), dtype=self.stroke_dtype)
|
||||
|
||||
def get_shader_info_list(self):
|
||||
stroke_info = get_shader_info(
|
||||
vert_file=self.stroke_vert_shader_file,
|
||||
geom_file=self.stroke_geom_shader_file,
|
||||
frag_file=self.stroke_frag_shader_file,
|
||||
texture_path=self.texture_path,
|
||||
render_primative=self.render_primative,
|
||||
)
|
||||
fill_info = get_shader_info(
|
||||
vert_file=self.fill_vert_shader_file,
|
||||
geom_file=self.fill_geom_shader_file,
|
||||
frag_file=self.fill_frag_shader_file,
|
||||
texture_path=self.texture_path,
|
||||
render_primative=self.render_primative,
|
||||
)
|
||||
|
||||
back_stroke_data = []
|
||||
stroke_data = []
|
||||
fill_data = []
|
||||
for submob in self.family_members_with_points():
|
||||
stroke_width = submob.get_stroke_width()
|
||||
stroke_opacity = submob.get_stroke_opacity()
|
||||
fill_opacity = submob.get_fill_opacity()
|
||||
|
||||
if fill_opacity > 0:
|
||||
fill_data.append(submob.get_fill_shader_data())
|
||||
|
||||
if stroke_width > 0 and stroke_opacity > 0:
|
||||
if submob.draw_stroke_behind_fill:
|
||||
data = back_stroke_data
|
||||
else:
|
||||
data = stroke_data
|
||||
data.append(submob.get_stroke_shader_data())
|
||||
|
||||
result = []
|
||||
if self.get_fill_opacity() > 0:
|
||||
result.append({
|
||||
"data": self.get_fill_shader_data(),
|
||||
"vert": self.fill_vert_shader_file,
|
||||
"geom": self.fill_geom_shader_file,
|
||||
"frag": self.fill_frag_shader_file,
|
||||
"render_primative": self.render_primative,
|
||||
"texture_path": self.texture_path,
|
||||
})
|
||||
if self.get_stroke_width() > 0 and self.get_stroke_opacity() > 0:
|
||||
result.append({
|
||||
"data": self.get_stroke_shader_data(),
|
||||
"vert": self.stroke_vert_shader_file,
|
||||
"geom": self.stroke_geom_shader_file,
|
||||
"frag": self.stroke_frag_shader_file,
|
||||
"render_primative": self.render_primative,
|
||||
"texture_path": self.texture_path,
|
||||
})
|
||||
if len(result) == 2 and self.draw_stroke_behind_fill:
|
||||
return [result[1], result[0]]
|
||||
if back_stroke_data:
|
||||
back_stroke_info = dict(stroke_info) # Copy
|
||||
back_stroke_info["data"] = np.hstack(back_stroke_data)
|
||||
result.append(back_stroke_info)
|
||||
if fill_data:
|
||||
fill_info["data"] = np.hstack(fill_data)
|
||||
result.append(fill_info)
|
||||
if stroke_data:
|
||||
stroke_info["data"] = np.hstack(stroke_data)
|
||||
result.append(stroke_info)
|
||||
return result
|
||||
|
||||
def get_stroke_shader_data(self):
|
||||
|
@ -878,6 +908,10 @@ class VMobject(Mobject):
|
|||
for sm in self.family_members_with_points():
|
||||
sm.triangulation_locked = False
|
||||
|
||||
def refresh_triangulation(self):
|
||||
if self.triangulation_locked:
|
||||
self.lock_triangulation()
|
||||
|
||||
def get_signed_polygonal_area(self):
|
||||
nppc = self.n_points_per_curve
|
||||
p0 = self.points[0::nppc]
|
||||
|
|
|
@ -120,7 +120,6 @@ class Scene(Container):
|
|||
def print_end_message(self):
|
||||
print(f"Played {self.num_plays} animations")
|
||||
|
||||
# TODO, remove this
|
||||
def set_variables_as_attrs(self, *objects, **newly_named_objects):
|
||||
"""
|
||||
This method is slightly hacky, making it a little easier
|
||||
|
@ -145,21 +144,21 @@ class Scene(Container):
|
|||
|
||||
def update_frame(self, dt=0, ignore_skipping=False):
|
||||
self.increment_time(dt)
|
||||
self.update_mobjects(dt) # Should skip?
|
||||
self.update_mobjects(dt)
|
||||
if self.skip_animations and not ignore_skipping:
|
||||
return
|
||||
|
||||
if self.window:
|
||||
self.window.clear()
|
||||
self.camera.clear()
|
||||
self.camera.capture(*self.get_displayed_mobjects())
|
||||
self.camera.capture(*self.mobjects)
|
||||
|
||||
if self.window:
|
||||
self.window.swap_buffers()
|
||||
win_time, win_dt = self.window.timer.next_frame()
|
||||
while (self.time - self.skip_time - win_time) > 0:
|
||||
self.window.clear()
|
||||
self.camera.capture(*self.get_displayed_mobjects())
|
||||
self.camera.capture(*self.mobjects)
|
||||
self.window.swap_buffers()
|
||||
win_time, win_dt = self.window.timer.next_frame()
|
||||
|
||||
|
@ -188,12 +187,6 @@ class Scene(Container):
|
|||
self.time += dt
|
||||
|
||||
###
|
||||
def get_displayed_mobjects(self):
|
||||
return it.chain(*[
|
||||
mob.family_members_with_points()
|
||||
for mob in self.mobjects
|
||||
])
|
||||
|
||||
def get_top_level_mobjects(self):
|
||||
# Return only those which are not in the family
|
||||
# of another mobject from the scene
|
||||
|
@ -278,7 +271,8 @@ class Scene(Container):
|
|||
step = 1 / self.camera.frame_rate
|
||||
times = np.arange(0, run_time, step)
|
||||
time_progression = ProgressDisplay(
|
||||
times, total=n_iterations,
|
||||
times,
|
||||
total=n_iterations,
|
||||
leave=self.leave_progress_bars,
|
||||
ascii=False if platform.system() != 'Windows' else True
|
||||
)
|
||||
|
@ -291,9 +285,8 @@ class Scene(Container):
|
|||
run_time = self.get_run_time(animations)
|
||||
time_progression = self.get_time_progression(run_time)
|
||||
time_progression.set_description("".join([
|
||||
"Animation {}: ".format(self.num_plays),
|
||||
str(animations[0]),
|
||||
(", etc." if len(animations) > 1 else ""),
|
||||
f"Animation {self.num_plays}: {animations[0]}",
|
||||
", etc." if len(animations) > 1 else "",
|
||||
]))
|
||||
return time_progression
|
||||
|
||||
|
@ -403,9 +396,6 @@ class Scene(Container):
|
|||
curr_mobjects += mob.get_family()
|
||||
|
||||
def progress_through_animations(self, animations):
|
||||
# Paint all non-moving objects onto the screen, so they don't
|
||||
# have to be rendered every frame
|
||||
# moving_mobjects = self.get_moving_mobjects(*animations)
|
||||
last_t = 0
|
||||
for t in self.get_animation_time_progression(animations):
|
||||
dt = t - last_t
|
||||
|
|
104
manimlib/utils/shaders.py
Normal file
104
manimlib/utils/shaders.py
Normal file
|
@ -0,0 +1,104 @@
|
|||
import os
|
||||
import warnings
|
||||
import re
|
||||
import moderngl
|
||||
|
||||
from manimlib.constants import SHADER_DIR
|
||||
|
||||
# Mobjects that should be rendered with
|
||||
# the same shader will be organized and
|
||||
# clumped together based on keeping track
|
||||
# of a dict holding all the relevant information
|
||||
# to that shader
|
||||
|
||||
|
||||
SHADER_INFO_KEYS = [
|
||||
"data",
|
||||
"vert",
|
||||
"geom",
|
||||
"frag",
|
||||
"texture_path",
|
||||
"render_primative",
|
||||
]
|
||||
|
||||
|
||||
def get_shader_info(data=None,
|
||||
vert_file=None,
|
||||
geom_file=None,
|
||||
frag_file=None,
|
||||
texture_path=None,
|
||||
render_primative=moderngl.TRIANGLE_STRIP):
|
||||
return {
|
||||
key: value
|
||||
for key, value in zip(
|
||||
SHADER_INFO_KEYS,
|
||||
[
|
||||
data, vert_file, geom_file, frag_file,
|
||||
texture_path, str(render_primative)
|
||||
]
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
def is_valid_shader_info(shader_info):
|
||||
data = shader_info["data"]
|
||||
return all([
|
||||
data is not None and len(data) > 0,
|
||||
shader_info["vert"],
|
||||
shader_info["frag"],
|
||||
])
|
||||
|
||||
|
||||
def shader_info_to_id(shader_info):
|
||||
# A unique id for a shader based on the
|
||||
# files holding its code and texture
|
||||
return "|".join([
|
||||
shader_info.get(key, "") or ""
|
||||
for key in SHADER_INFO_KEYS[1:]
|
||||
])
|
||||
|
||||
|
||||
def shader_id_to_info(sid):
|
||||
return {
|
||||
key: (value or None)
|
||||
for key, value in zip(
|
||||
SHADER_INFO_KEYS,
|
||||
[None, *sid.split("|")]
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
def same_shader_type(info1, info2):
|
||||
return all([
|
||||
info1[key] == info2[key]
|
||||
for key in [
|
||||
"vert",
|
||||
"geom",
|
||||
"frag",
|
||||
"texture_path",
|
||||
"render_primative",
|
||||
]
|
||||
])
|
||||
|
||||
|
||||
def get_shader_code_from_file(filename):
|
||||
if not filename:
|
||||
return None
|
||||
|
||||
filepath = os.path.join(SHADER_DIR, filename)
|
||||
if not os.path.exists(filepath):
|
||||
warnings.warn(f"No file at {filepath}")
|
||||
return
|
||||
|
||||
with open(filepath, "r") as f:
|
||||
result = f.read()
|
||||
|
||||
# To share functionality between shaders, some functions are read in
|
||||
# from other files an inserted into the relevant strings before
|
||||
# passing to ctx.program for compiling
|
||||
# Replace "#INSERT " lines with relevant code
|
||||
insertions = re.findall(r"^#INSERT .*\.glsl$", result, flags=re.MULTILINE)
|
||||
for line in insertions:
|
||||
inserted_code = get_shader_code_from_file(line.replace("#INSERT ", ""))
|
||||
result = result.replace(line, inserted_code)
|
||||
return result
|
Loading…
Add table
Reference in a new issue