From a92a5062242a2163bde79040b786ee129f9da068 Mon Sep 17 00:00:00 2001 From: Grant Sanderson Date: Wed, 28 Dec 2022 19:17:52 -0800 Subject: [PATCH] Don't distinguish stroke uniforms from fill uniforms --- manimlib/mobject/types/vectorized_mobject.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/manimlib/mobject/types/vectorized_mobject.py b/manimlib/mobject/types/vectorized_mobject.py index 62ccdf44..db2d9b50 100644 --- a/manimlib/mobject/types/vectorized_mobject.py +++ b/manimlib/mobject/types/vectorized_mobject.py @@ -122,6 +122,8 @@ class VMobject(Mobject): def init_uniforms(self): super().init_uniforms() self.uniforms["anti_alias_width"] = self.anti_alias_width + self.uniforms["joint_type"] = JOINT_TYPE_MAP[self.joint_type] + self.uniforms["flat_stroke"] = float(self.flat_stroke) # These are here just to make type checkers happy def get_family(self, recurse: bool = True) -> list[VMobject]: @@ -396,19 +398,19 @@ class VMobject(Mobject): def set_flat_stroke(self, flat_stroke: bool = True, recurse: bool = True): for mob in self.get_family(recurse): - mob.flat_stroke = flat_stroke + mob.uniforms["flat_stroke"] = float(flat_stroke) return self def get_flat_stroke(self) -> bool: - return self.flat_stroke + return self.uniforms["flat_stroke"] == 1.0 def set_joint_type(self, joint_type: str, recurse: bool = True): for mob in self.get_family(recurse): - mob.joint_type = joint_type + mob.uniforms["joint_type"] = JOINT_TYPE_MAP[joint_type] return self - def get_joint_type(self) -> str: - return self.joint_type + def get_joint_type(self) -> float: + return self.uniforms["joint_type"] # Points def set_anchors_and_handles( @@ -1066,7 +1068,7 @@ class VMobject(Mobject): ) self.stroke_shader_wrapper = ShaderWrapper( vert_data=self.stroke_data, - uniforms=self.get_stroke_uniforms(), + uniforms=self.uniforms, shader_folder=self.stroke_shader_folder, render_primitive=self.render_primitive, ) @@ -1089,7 +1091,7 @@ class VMobject(Mobject): def get_stroke_shader_wrapper(self) -> ShaderWrapper: self.stroke_shader_wrapper.vert_data = self.get_stroke_shader_data() - self.stroke_shader_wrapper.uniforms = self.get_stroke_uniforms() + self.stroke_shader_wrapper.uniforms = self.get_shader_uniforms() self.stroke_shader_wrapper.depth_test = self.depth_test return self.stroke_shader_wrapper @@ -1116,13 +1118,13 @@ class VMobject(Mobject): ] for i, sw in enumerate(result): sw.depth_test = self.depth_test - return list(filter(lambda sw: len(sw.vert_data) > 0, result)) - def get_stroke_uniforms(self) -> dict[str, float]: result = dict(super().get_shader_uniforms()) result["joint_type"] = JOINT_TYPE_MAP[self.joint_type] result["flat_stroke"] = float(self.flat_stroke) return result + sw.uniforms = self.uniforms + return list(filter(lambda sw: len(sw.vert_data) > 0, self.shader_wrapper_list)) def get_stroke_shader_data(self) -> np.ndarray: points = self.get_points()