mirror of
https://github.com/3b1b/manim.git
synced 2025-08-05 16:49:03 +00:00
Unify get_shader_wrapper_list, and and better subdivide render groups by ShaderWrapper ids
This commit is contained in:
parent
08e33faab8
commit
f12b143d16
4 changed files with 50 additions and 65 deletions
|
@ -1944,8 +1944,9 @@ class Mobject(object):
|
||||||
|
|
||||||
@affects_data
|
@affects_data
|
||||||
def replace_shader_code(self, old: str, new: str) -> Self:
|
def replace_shader_code(self, old: str, new: str) -> Self:
|
||||||
self.shader_code_replacements[old] = new
|
for mob in self.get_family():
|
||||||
self._shaders_initialized = False
|
mob.shader_code_replacements[old] = new
|
||||||
|
mob._shaders_initialized = False
|
||||||
for mob in self.get_ancestors():
|
for mob in self.get_ancestors():
|
||||||
mob._shaders_initialized = False
|
mob._shaders_initialized = False
|
||||||
return self
|
return self
|
||||||
|
@ -2000,48 +2001,45 @@ class Mobject(object):
|
||||||
texture_paths=self.texture_paths,
|
texture_paths=self.texture_paths,
|
||||||
depth_test=self.depth_test,
|
depth_test=self.depth_test,
|
||||||
render_primitive=self.render_primitive,
|
render_primitive=self.render_primitive,
|
||||||
|
code_replacements=self.shader_code_replacements,
|
||||||
)
|
)
|
||||||
|
|
||||||
def refresh_shader_wrapper_id(self):
|
def refresh_shader_wrapper_id(self):
|
||||||
if self._shaders_initialized:
|
for submob in self.get_family():
|
||||||
self.shader_wrapper.refresh_id()
|
if submob._shaders_initialized:
|
||||||
|
submob.shader_wrapper.depth_test = submob.depth_test
|
||||||
|
submob.shader_wrapper.refresh_id()
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def get_shader_wrapper(self, ctx: Context) -> ShaderWrapper:
|
def get_shader_wrapper(self, ctx: Context) -> ShaderWrapper:
|
||||||
if not self._shaders_initialized:
|
if not self._shaders_initialized:
|
||||||
self.init_shader_data(ctx)
|
self.init_shader_data(ctx)
|
||||||
self._shaders_initialized = True
|
self._shaders_initialized = True
|
||||||
|
|
||||||
self.shader_wrapper.bind_to_mobject_uniforms(self.get_uniforms())
|
|
||||||
self.shader_wrapper.depth_test = self.depth_test
|
|
||||||
for old, new in self.shader_code_replacements.items():
|
|
||||||
self.shader_wrapper.replace_code(old, new)
|
|
||||||
return self.shader_wrapper
|
return self.shader_wrapper
|
||||||
|
|
||||||
def get_shader_wrapper_list(self, ctx: Context) -> list[ShaderWrapper]:
|
def get_shader_wrapper_list(self, ctx: Context) -> list[ShaderWrapper]:
|
||||||
family = self.family_members_with_points()
|
family = self.family_members_with_points()
|
||||||
for submob in family:
|
batches = batch_by_property(family, lambda sm: sm.get_shader_wrapper(ctx).get_id())
|
||||||
submob.get_shader_wrapper(ctx)
|
|
||||||
batches = batch_by_property(family, lambda submob: submob.shader_wrapper.get_id())
|
|
||||||
|
|
||||||
result = []
|
result = []
|
||||||
for submobs, sid in batches:
|
for submobs, sid in batches:
|
||||||
shader_wrapper = submobs[0].shader_wrapper
|
shader_wrapper = submobs[0].shader_wrapper
|
||||||
data_list = [sm.get_shader_data() for sm in submobs]
|
data_list = list(it.chain(*(sm.get_shader_data() for sm in submobs)))
|
||||||
indices_list = [sm.get_shader_vert_indices() for sm in submobs]
|
shader_wrapper.read_in(data_list, indices_list=None)
|
||||||
if indices_list[0] is None:
|
|
||||||
indices_list = None
|
|
||||||
shader_wrapper.read_in(data_list, indices_list)
|
|
||||||
result.append(shader_wrapper)
|
result.append(shader_wrapper)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def get_shader_data(self):
|
def get_shader_data(self) -> Iterable[np.ndarray]:
|
||||||
return self.data
|
indices = self.get_shader_vert_indices()
|
||||||
|
if indices is not None:
|
||||||
|
return [self.data[indices]]
|
||||||
|
else:
|
||||||
|
return [self.data]
|
||||||
|
|
||||||
def get_uniforms(self):
|
def get_uniforms(self):
|
||||||
return self.uniforms
|
return self.uniforms
|
||||||
|
|
||||||
def get_shader_vert_indices(self):
|
def get_shader_vert_indices(self) -> Optional[np.ndarray]:
|
||||||
return self.shader_indices
|
return self.shader_indices
|
||||||
|
|
||||||
def render(self, ctx: Context, camera_uniforms: dict):
|
def render(self, ctx: Context, camera_uniforms: dict):
|
||||||
|
|
|
@ -191,6 +191,7 @@ class VMobject(Mobject):
|
||||||
if background is not None:
|
if background is not None:
|
||||||
for mob in self.get_family(recurse):
|
for mob in self.get_family(recurse):
|
||||||
mob.stroke_behind = background
|
mob.stroke_behind = background
|
||||||
|
mob.refresh_shader_wrapper_id()
|
||||||
|
|
||||||
if flat is not None:
|
if flat is not None:
|
||||||
self.set_flat_stroke(flat)
|
self.set_flat_stroke(flat)
|
||||||
|
@ -1285,44 +1286,27 @@ class VMobject(Mobject):
|
||||||
# For shaders
|
# For shaders
|
||||||
|
|
||||||
def init_shader_data(self, ctx: Context):
|
def init_shader_data(self, ctx: Context):
|
||||||
self.shader_indices = np.zeros(0)
|
self.shader_indices = None
|
||||||
self.shader_wrapper = VShaderWrapper(
|
self.shader_wrapper = VShaderWrapper(
|
||||||
ctx=ctx,
|
ctx=ctx,
|
||||||
vert_data=self.data,
|
vert_data=self.data,
|
||||||
mobject_uniforms=self.uniforms,
|
mobject_uniforms=self.uniforms,
|
||||||
|
code_replacements=self.shader_code_replacements,
|
||||||
|
stroke_behind=self.stroke_behind,
|
||||||
|
depth_test=self.depth_test
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_shader_vert_indices(self):
|
def refresh_shader_wrapper_id(self):
|
||||||
return None
|
for submob in self.get_family():
|
||||||
|
if submob._shaders_initialized:
|
||||||
|
submob.shader_wrapper.stroke_behind = submob.stroke_behind
|
||||||
|
super().refresh_shader_wrapper_id()
|
||||||
|
return self
|
||||||
|
|
||||||
def get_shader_data(self):
|
def get_shader_data(self) -> Iterable[np.ndarray]:
|
||||||
# This should only come up when VMobjects appear together in a group
|
# Do we want this elsewhere? Say whenever points are refreshed or something?
|
||||||
return np.hstack([self.data, self.data[-1:]])
|
self.data["base_normal"][0::2] = self.data["point"][0]
|
||||||
|
return [self.data, self.data[-1:]]
|
||||||
def get_shader_wrapper_list(self, ctx: Context) -> list[ShaderWrapper]:
|
|
||||||
if not self._shaders_initialized:
|
|
||||||
self.init_shader_data(ctx)
|
|
||||||
self._shaders_initialized = True
|
|
||||||
|
|
||||||
family = self.family_members_with_points()
|
|
||||||
if not family:
|
|
||||||
return []
|
|
||||||
|
|
||||||
stroke_behind = False
|
|
||||||
for submob in family:
|
|
||||||
# Maybe do this on set points instead? Or on noting changed data?
|
|
||||||
submob.data["base_normal"][0::2] = submob.data["point"][0]
|
|
||||||
if submob.stroke_behind:
|
|
||||||
stroke_behind = True
|
|
||||||
|
|
||||||
self.shader_wrapper.read_in(
|
|
||||||
list(it.chain(*([sm.data, sm.data[-1:]] for sm in family)))
|
|
||||||
)
|
|
||||||
rep = family[0] # Representative family member
|
|
||||||
self.shader_wrapper.bind_to_mobject_uniforms(rep.get_uniforms())
|
|
||||||
self.shader_wrapper.depth_test = rep.depth_test
|
|
||||||
self.shader_wrapper.stroke_behind = stroke_behind
|
|
||||||
return [self.shader_wrapper]
|
|
||||||
|
|
||||||
|
|
||||||
class VGroup(Group, VMobject, Generic[SubVmobjectType]):
|
class VGroup(Group, VMobject, Generic[SubVmobjectType]):
|
||||||
|
|
|
@ -388,7 +388,7 @@ class Scene(object):
|
||||||
"""
|
"""
|
||||||
batches = batch_by_property(
|
batches = batch_by_property(
|
||||||
self.mobjects,
|
self.mobjects,
|
||||||
lambda m: str(type(m)) + str(m.get_uniforms())
|
lambda m: m.get_shader_wrapper(self.camera.ctx).get_id()
|
||||||
)
|
)
|
||||||
|
|
||||||
for group in self.render_groups:
|
for group in self.render_groups:
|
||||||
|
|
|
@ -41,6 +41,7 @@ class ShaderWrapper(object):
|
||||||
texture_paths: Optional[dict[str, str]] = None, # A dictionary mapping names to filepaths for textures.
|
texture_paths: Optional[dict[str, str]] = None, # A dictionary mapping names to filepaths for textures.
|
||||||
depth_test: bool = False,
|
depth_test: bool = False,
|
||||||
render_primitive: int = moderngl.TRIANGLE_STRIP,
|
render_primitive: int = moderngl.TRIANGLE_STRIP,
|
||||||
|
code_replacements: dict[str, str] = dict(),
|
||||||
):
|
):
|
||||||
self.ctx = ctx
|
self.ctx = ctx
|
||||||
self.vert_data = vert_data
|
self.vert_data = vert_data
|
||||||
|
@ -48,13 +49,15 @@ class ShaderWrapper(object):
|
||||||
self.shader_folder = shader_folder
|
self.shader_folder = shader_folder
|
||||||
self.depth_test = depth_test
|
self.depth_test = depth_test
|
||||||
self.render_primitive = render_primitive
|
self.render_primitive = render_primitive
|
||||||
|
self.texture_names_to_ids = dict()
|
||||||
|
|
||||||
self.program_uniform_mirror: UniformDict = dict()
|
self.program_uniform_mirror: UniformDict = dict()
|
||||||
self.bind_to_mobject_uniforms(mobject_uniforms or dict())
|
self.bind_to_mobject_uniforms(mobject_uniforms or dict())
|
||||||
|
|
||||||
self.init_program_code()
|
self.init_program_code()
|
||||||
|
for old, new in code_replacements.items():
|
||||||
|
self.replace_code(old, new)
|
||||||
self.init_program()
|
self.init_program()
|
||||||
self.texture_names_to_ids = dict()
|
|
||||||
if texture_paths is not None:
|
if texture_paths is not None:
|
||||||
self.init_textures(texture_paths)
|
self.init_textures(texture_paths)
|
||||||
self.init_vertex_objects()
|
self.init_vertex_objects()
|
||||||
|
@ -98,19 +101,14 @@ class ShaderWrapper(object):
|
||||||
def get_id(self) -> str:
|
def get_id(self) -> str:
|
||||||
return self.id
|
return self.id
|
||||||
|
|
||||||
def create_id(self) -> str:
|
def refresh_id(self) -> None:
|
||||||
# A unique id for a shader
|
self.id = hash("".join(map(str, [
|
||||||
program_id = hash("".join(map(str, self.program_code.values())))
|
"".join(map(str, self.program_code.values())),
|
||||||
return "|".join(map(str, [
|
|
||||||
program_id,
|
|
||||||
self.mobject_uniforms,
|
self.mobject_uniforms,
|
||||||
self.depth_test,
|
self.depth_test,
|
||||||
self.render_primitive,
|
self.render_primitive,
|
||||||
self.texture_names_to_ids,
|
self.texture_names_to_ids,
|
||||||
]))
|
])))
|
||||||
|
|
||||||
def refresh_id(self) -> None:
|
|
||||||
self.id = self.create_id()
|
|
||||||
|
|
||||||
def replace_code(self, old: str, new: str) -> None:
|
def replace_code(self, old: str, new: str) -> None:
|
||||||
code_map = self.program_code
|
code_map = self.program_code
|
||||||
|
@ -213,10 +211,12 @@ class VShaderWrapper(ShaderWrapper):
|
||||||
mobject_uniforms: Optional[UniformDict] = None, # A dictionary mapping names of uniform variables
|
mobject_uniforms: Optional[UniformDict] = None, # A dictionary mapping names of uniform variables
|
||||||
texture_paths: Optional[dict[str, str]] = None, # A dictionary mapping names to filepaths for textures.
|
texture_paths: Optional[dict[str, str]] = None, # A dictionary mapping names to filepaths for textures.
|
||||||
depth_test: bool = False,
|
depth_test: bool = False,
|
||||||
# render_primitive: int = moderngl.TRIANGLES,
|
|
||||||
render_primitive: int = moderngl.TRIANGLE_STRIP,
|
render_primitive: int = moderngl.TRIANGLE_STRIP,
|
||||||
|
code_replacements: dict[str, str] = dict(),
|
||||||
stroke_behind: bool = False,
|
stroke_behind: bool = False,
|
||||||
):
|
):
|
||||||
|
self.stroke_behind = stroke_behind
|
||||||
|
self.fill_canvas = get_fill_canvas(ctx)
|
||||||
super().__init__(
|
super().__init__(
|
||||||
ctx=ctx,
|
ctx=ctx,
|
||||||
vert_data=vert_data,
|
vert_data=vert_data,
|
||||||
|
@ -225,9 +225,8 @@ class VShaderWrapper(ShaderWrapper):
|
||||||
texture_paths=texture_paths,
|
texture_paths=texture_paths,
|
||||||
depth_test=depth_test,
|
depth_test=depth_test,
|
||||||
render_primitive=render_primitive,
|
render_primitive=render_primitive,
|
||||||
|
code_replacements=code_replacements,
|
||||||
)
|
)
|
||||||
self.stroke_behind = stroke_behind
|
|
||||||
self.fill_canvas = get_fill_canvas(self.ctx)
|
|
||||||
|
|
||||||
def init_program_code(self) -> None:
|
def init_program_code(self) -> None:
|
||||||
self.program_code = {
|
self.program_code = {
|
||||||
|
@ -312,6 +311,10 @@ class VShaderWrapper(ShaderWrapper):
|
||||||
def set_backstroke(self, value: bool = True):
|
def set_backstroke(self, value: bool = True):
|
||||||
self.stroke_behind = value
|
self.stroke_behind = value
|
||||||
|
|
||||||
|
def refresh_id(self):
|
||||||
|
super().refresh_id()
|
||||||
|
self.id = hash(str(self.id) + str(self.stroke_behind))
|
||||||
|
|
||||||
# TODO, motidify read in to handle triangulation case for non-winding fill?
|
# TODO, motidify read in to handle triangulation case for non-winding fill?
|
||||||
|
|
||||||
# Rendering
|
# Rendering
|
||||||
|
|
Loading…
Add table
Reference in a new issue