mirror of
https://github.com/3b1b/manim.git
synced 2025-04-13 09:47:07 +00:00
Unify get_shader_wrapper_list, and and better subdivide render groups by ShaderWrapper ids
This commit is contained in:
parent
08e33faab8
commit
f12b143d16
4 changed files with 50 additions and 65 deletions
|
@ -1944,8 +1944,9 @@ class Mobject(object):
|
|||
|
||||
@affects_data
|
||||
def replace_shader_code(self, old: str, new: str) -> Self:
|
||||
self.shader_code_replacements[old] = new
|
||||
self._shaders_initialized = False
|
||||
for mob in self.get_family():
|
||||
mob.shader_code_replacements[old] = new
|
||||
mob._shaders_initialized = False
|
||||
for mob in self.get_ancestors():
|
||||
mob._shaders_initialized = False
|
||||
return self
|
||||
|
@ -2000,48 +2001,45 @@ class Mobject(object):
|
|||
texture_paths=self.texture_paths,
|
||||
depth_test=self.depth_test,
|
||||
render_primitive=self.render_primitive,
|
||||
code_replacements=self.shader_code_replacements,
|
||||
)
|
||||
|
||||
def refresh_shader_wrapper_id(self):
|
||||
if self._shaders_initialized:
|
||||
self.shader_wrapper.refresh_id()
|
||||
for submob in self.get_family():
|
||||
if submob._shaders_initialized:
|
||||
submob.shader_wrapper.depth_test = submob.depth_test
|
||||
submob.shader_wrapper.refresh_id()
|
||||
return self
|
||||
|
||||
def get_shader_wrapper(self, ctx: Context) -> ShaderWrapper:
|
||||
if not self._shaders_initialized:
|
||||
self.init_shader_data(ctx)
|
||||
self._shaders_initialized = True
|
||||
|
||||
self.shader_wrapper.bind_to_mobject_uniforms(self.get_uniforms())
|
||||
self.shader_wrapper.depth_test = self.depth_test
|
||||
for old, new in self.shader_code_replacements.items():
|
||||
self.shader_wrapper.replace_code(old, new)
|
||||
return self.shader_wrapper
|
||||
|
||||
def get_shader_wrapper_list(self, ctx: Context) -> list[ShaderWrapper]:
|
||||
family = self.family_members_with_points()
|
||||
for submob in family:
|
||||
submob.get_shader_wrapper(ctx)
|
||||
batches = batch_by_property(family, lambda submob: submob.shader_wrapper.get_id())
|
||||
batches = batch_by_property(family, lambda sm: sm.get_shader_wrapper(ctx).get_id())
|
||||
|
||||
result = []
|
||||
for submobs, sid in batches:
|
||||
shader_wrapper = submobs[0].shader_wrapper
|
||||
data_list = [sm.get_shader_data() for sm in submobs]
|
||||
indices_list = [sm.get_shader_vert_indices() for sm in submobs]
|
||||
if indices_list[0] is None:
|
||||
indices_list = None
|
||||
shader_wrapper.read_in(data_list, indices_list)
|
||||
data_list = list(it.chain(*(sm.get_shader_data() for sm in submobs)))
|
||||
shader_wrapper.read_in(data_list, indices_list=None)
|
||||
result.append(shader_wrapper)
|
||||
return result
|
||||
|
||||
def get_shader_data(self):
|
||||
return self.data
|
||||
def get_shader_data(self) -> Iterable[np.ndarray]:
|
||||
indices = self.get_shader_vert_indices()
|
||||
if indices is not None:
|
||||
return [self.data[indices]]
|
||||
else:
|
||||
return [self.data]
|
||||
|
||||
def get_uniforms(self):
|
||||
return self.uniforms
|
||||
|
||||
def get_shader_vert_indices(self):
|
||||
def get_shader_vert_indices(self) -> Optional[np.ndarray]:
|
||||
return self.shader_indices
|
||||
|
||||
def render(self, ctx: Context, camera_uniforms: dict):
|
||||
|
|
|
@ -191,6 +191,7 @@ class VMobject(Mobject):
|
|||
if background is not None:
|
||||
for mob in self.get_family(recurse):
|
||||
mob.stroke_behind = background
|
||||
mob.refresh_shader_wrapper_id()
|
||||
|
||||
if flat is not None:
|
||||
self.set_flat_stroke(flat)
|
||||
|
@ -1285,44 +1286,27 @@ class VMobject(Mobject):
|
|||
# For shaders
|
||||
|
||||
def init_shader_data(self, ctx: Context):
|
||||
self.shader_indices = np.zeros(0)
|
||||
self.shader_indices = None
|
||||
self.shader_wrapper = VShaderWrapper(
|
||||
ctx=ctx,
|
||||
vert_data=self.data,
|
||||
mobject_uniforms=self.uniforms,
|
||||
code_replacements=self.shader_code_replacements,
|
||||
stroke_behind=self.stroke_behind,
|
||||
depth_test=self.depth_test
|
||||
)
|
||||
|
||||
def get_shader_vert_indices(self):
|
||||
return None
|
||||
def refresh_shader_wrapper_id(self):
|
||||
for submob in self.get_family():
|
||||
if submob._shaders_initialized:
|
||||
submob.shader_wrapper.stroke_behind = submob.stroke_behind
|
||||
super().refresh_shader_wrapper_id()
|
||||
return self
|
||||
|
||||
def get_shader_data(self):
|
||||
# This should only come up when VMobjects appear together in a group
|
||||
return np.hstack([self.data, self.data[-1:]])
|
||||
|
||||
def get_shader_wrapper_list(self, ctx: Context) -> list[ShaderWrapper]:
|
||||
if not self._shaders_initialized:
|
||||
self.init_shader_data(ctx)
|
||||
self._shaders_initialized = True
|
||||
|
||||
family = self.family_members_with_points()
|
||||
if not family:
|
||||
return []
|
||||
|
||||
stroke_behind = False
|
||||
for submob in family:
|
||||
# Maybe do this on set points instead? Or on noting changed data?
|
||||
submob.data["base_normal"][0::2] = submob.data["point"][0]
|
||||
if submob.stroke_behind:
|
||||
stroke_behind = True
|
||||
|
||||
self.shader_wrapper.read_in(
|
||||
list(it.chain(*([sm.data, sm.data[-1:]] for sm in family)))
|
||||
)
|
||||
rep = family[0] # Representative family member
|
||||
self.shader_wrapper.bind_to_mobject_uniforms(rep.get_uniforms())
|
||||
self.shader_wrapper.depth_test = rep.depth_test
|
||||
self.shader_wrapper.stroke_behind = stroke_behind
|
||||
return [self.shader_wrapper]
|
||||
def get_shader_data(self) -> Iterable[np.ndarray]:
|
||||
# Do we want this elsewhere? Say whenever points are refreshed or something?
|
||||
self.data["base_normal"][0::2] = self.data["point"][0]
|
||||
return [self.data, self.data[-1:]]
|
||||
|
||||
|
||||
class VGroup(Group, VMobject, Generic[SubVmobjectType]):
|
||||
|
|
|
@ -388,7 +388,7 @@ class Scene(object):
|
|||
"""
|
||||
batches = batch_by_property(
|
||||
self.mobjects,
|
||||
lambda m: str(type(m)) + str(m.get_uniforms())
|
||||
lambda m: m.get_shader_wrapper(self.camera.ctx).get_id()
|
||||
)
|
||||
|
||||
for group in self.render_groups:
|
||||
|
|
|
@ -41,6 +41,7 @@ class ShaderWrapper(object):
|
|||
texture_paths: Optional[dict[str, str]] = None, # A dictionary mapping names to filepaths for textures.
|
||||
depth_test: bool = False,
|
||||
render_primitive: int = moderngl.TRIANGLE_STRIP,
|
||||
code_replacements: dict[str, str] = dict(),
|
||||
):
|
||||
self.ctx = ctx
|
||||
self.vert_data = vert_data
|
||||
|
@ -48,13 +49,15 @@ class ShaderWrapper(object):
|
|||
self.shader_folder = shader_folder
|
||||
self.depth_test = depth_test
|
||||
self.render_primitive = render_primitive
|
||||
self.texture_names_to_ids = dict()
|
||||
|
||||
self.program_uniform_mirror: UniformDict = dict()
|
||||
self.bind_to_mobject_uniforms(mobject_uniforms or dict())
|
||||
|
||||
self.init_program_code()
|
||||
for old, new in code_replacements.items():
|
||||
self.replace_code(old, new)
|
||||
self.init_program()
|
||||
self.texture_names_to_ids = dict()
|
||||
if texture_paths is not None:
|
||||
self.init_textures(texture_paths)
|
||||
self.init_vertex_objects()
|
||||
|
@ -98,19 +101,14 @@ class ShaderWrapper(object):
|
|||
def get_id(self) -> str:
|
||||
return self.id
|
||||
|
||||
def create_id(self) -> str:
|
||||
# A unique id for a shader
|
||||
program_id = hash("".join(map(str, self.program_code.values())))
|
||||
return "|".join(map(str, [
|
||||
program_id,
|
||||
def refresh_id(self) -> None:
|
||||
self.id = hash("".join(map(str, [
|
||||
"".join(map(str, self.program_code.values())),
|
||||
self.mobject_uniforms,
|
||||
self.depth_test,
|
||||
self.render_primitive,
|
||||
self.texture_names_to_ids,
|
||||
]))
|
||||
|
||||
def refresh_id(self) -> None:
|
||||
self.id = self.create_id()
|
||||
])))
|
||||
|
||||
def replace_code(self, old: str, new: str) -> None:
|
||||
code_map = self.program_code
|
||||
|
@ -213,10 +211,12 @@ class VShaderWrapper(ShaderWrapper):
|
|||
mobject_uniforms: Optional[UniformDict] = None, # A dictionary mapping names of uniform variables
|
||||
texture_paths: Optional[dict[str, str]] = None, # A dictionary mapping names to filepaths for textures.
|
||||
depth_test: bool = False,
|
||||
# render_primitive: int = moderngl.TRIANGLES,
|
||||
render_primitive: int = moderngl.TRIANGLE_STRIP,
|
||||
code_replacements: dict[str, str] = dict(),
|
||||
stroke_behind: bool = False,
|
||||
):
|
||||
self.stroke_behind = stroke_behind
|
||||
self.fill_canvas = get_fill_canvas(ctx)
|
||||
super().__init__(
|
||||
ctx=ctx,
|
||||
vert_data=vert_data,
|
||||
|
@ -225,9 +225,8 @@ class VShaderWrapper(ShaderWrapper):
|
|||
texture_paths=texture_paths,
|
||||
depth_test=depth_test,
|
||||
render_primitive=render_primitive,
|
||||
code_replacements=code_replacements,
|
||||
)
|
||||
self.stroke_behind = stroke_behind
|
||||
self.fill_canvas = get_fill_canvas(self.ctx)
|
||||
|
||||
def init_program_code(self) -> None:
|
||||
self.program_code = {
|
||||
|
@ -312,6 +311,10 @@ class VShaderWrapper(ShaderWrapper):
|
|||
def set_backstroke(self, value: bool = True):
|
||||
self.stroke_behind = value
|
||||
|
||||
def refresh_id(self):
|
||||
super().refresh_id()
|
||||
self.id = hash(str(self.id) + str(self.stroke_behind))
|
||||
|
||||
# TODO, motidify read in to handle triangulation case for non-winding fill?
|
||||
|
||||
# Rendering
|
||||
|
|
Loading…
Add table
Reference in a new issue