Have ShaderWrapper read in data rather than other shader wrappers

This commit is contained in:
Grant Sanderson 2023-01-15 20:27:19 -08:00
parent f63331eb24
commit ba9f61b50b
2 changed files with 52 additions and 35 deletions

View file

@ -1212,37 +1212,40 @@ class VMobject(Mobject):
def get_shader_wrapper_list(self) -> list[ShaderWrapper]:
# Build up data lists
fill_sws = []
stroke_sws = []
bstroke_sws = []
fill_submobs = []
stroke_submobs = []
bstroke_submobs = []
for submob in self.family_members_with_points():
if submob.has_fill():
fill_sws.append(submob.get_fill_shader_wrapper())
fill_submobs.append(submob)
if submob.has_stroke():
lst = bstroke_sws if submob.draw_stroke_behind_fill else stroke_sws
lst.append(submob.get_stroke_shader_wrapper())
if submob.draw_stroke_behind_fill:
bstroke_submobs.append(submob)
else:
stroke_submobs.append(submob)
self_sws = [
fill_names = list(self.fill_data.dtype.names)
self.fill_shader_wrapper.read_in(
[sm.data[fill_names] for sm in fill_submobs],
[sm.get_fill_shader_vert_indices() for sm in fill_submobs],
)
self.stroke_shader_wrapper.read_in(
[sm.get_stroke_shader_data() for sm in stroke_submobs],
)
self.back_stroke_shader_wrapper.read_in(
[sm.get_stroke_shader_data() for sm in bstroke_submobs],
)
shader_wrappers = [
self.back_stroke_shader_wrapper,
self.fill_shader_wrapper,
self.stroke_shader_wrapper
]
sw_lists = [
bstroke_sws,
fill_sws,
stroke_sws
]
for sw, sw_list in zip(self_sws, sw_lists):
if not sw_list:
sw.vert_data = resize_array(sw.vert_data, 0)
continue
if sw is sw_list[0]:
sw.combine_with(*sw_list[1:])
else:
sw.read_in(*sw_list)
sw.depth_test = any(sw.depth_test for sw in sw_list)
sw.uniforms.update(sw_list[0].uniforms)
return [sw for sw in self_sws if len(sw.vert_data) > 0]
for sw in shader_wrappers:
# TODO, handle depth test and uniforms...
pass
return [sw for sw in shader_wrappers if len(sw.vert_data) > 0]
def get_stroke_shader_data(self) -> np.ndarray:
# Set data array to be one longer than number of points,
@ -1255,6 +1258,7 @@ class VMobject(Mobject):
return self.stroke_data
self.get_joint_angles() # Recomputes, only if refresh is needed
for key in self.stroke_data.dtype.names:
self.stroke_data[key][:n] = self.data[key]
self.stroke_data[-1] = self.stroke_data[-2]

View file

@ -16,7 +16,7 @@ from manimlib.utils.iterables import resize_array
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Iterable
from typing import Iterable, List
# Mobjects that should be rendered with
@ -136,25 +136,38 @@ class ShaderWrapper(object):
def combine_with(self, *shader_wrappers: ShaderWrapper) -> ShaderWrapper:
if len(shader_wrappers) > 0:
self.read_in(self.copy(), *shader_wrappers)
data_list = [self.vert_data, *(sw.vert_data for sw in shader_wrappers)]
if self.vert_indices is not None:
indices_list = [self.vert_indices, *(sw.vert_indices for sw in shader_wrappers)]
else:
indices_list = None
self.read_in(data_list, indices_list)
return self
def read_in(self, *shader_wrappers: ShaderWrapper) -> ShaderWrapper:
def read_in(
self,
vert_data_list: List[np.ndarray],
vert_indices_list: List[np.ndarray] | None = None
) -> ShaderWrapper:
# Assume all are of the same type
total_len = sum(len(sw.vert_data) for sw in shader_wrappers)
total_len = sum(len(data) for data in vert_data_list)
self.vert_data = resize_array(self.vert_data, total_len)
if self.vert_indices is not None:
total_verts = sum(len(sw.vert_indices) for sw in shader_wrappers)
if total_len == 0:
return self
if vert_indices_list is not None and self.vert_indices is not None:
total_verts = sum(len(vi) for vi in vert_indices_list)
self.vert_indices = resize_array(self.vert_indices, total_verts)
n_points = 0
n_verts = 0
for sw in shader_wrappers:
new_n_points = n_points + len(sw.vert_data)
self.vert_data[n_points:new_n_points] = sw.vert_data
if self.vert_indices is not None and sw.vert_indices is not None:
new_n_verts = n_verts + len(sw.vert_indices)
self.vert_indices[n_verts:new_n_verts] = sw.vert_indices + n_points
for k, data in enumerate(vert_data_list):
new_n_points = n_points + len(data)
self.vert_data[n_points:new_n_points] = data
if self.vert_indices is not None and vert_indices_list is not None:
vert_indices = vert_indices_list[k]
new_n_verts = n_verts + len(vert_indices)
self.vert_indices[n_verts:new_n_verts] = vert_indices + n_points
n_verts = new_n_verts
n_points = new_n_points
return self