From 80c0e8813396309210f8503106fb29a46c0e43fd Mon Sep 17 00:00:00 2001 From: Grant Sanderson Date: Mon, 9 Jan 2023 10:08:03 -0800 Subject: [PATCH] Put joint_angle information in VMobject.data --- manimlib/mobject/types/vectorized_mobject.py | 49 +++++++++----------- 1 file changed, 23 insertions(+), 26 deletions(-) diff --git a/manimlib/mobject/types/vectorized_mobject.py b/manimlib/mobject/types/vectorized_mobject.py index ec3d3fca..74f59bc6 100644 --- a/manimlib/mobject/types/vectorized_mobject.py +++ b/manimlib/mobject/types/vectorized_mobject.py @@ -120,6 +120,7 @@ class VMobject(Mobject): "stroke_rgba": np.zeros((1, 4)), "stroke_width": np.zeros((1, 1)), "orientation": np.ones((1, 1)), + "joint_angle": np.zeros((1, 1)), }) def init_uniforms(self): @@ -1031,21 +1032,18 @@ class VMobject(Mobject): mob.needs_new_joint_angles = True return self - def get_joint_angles(self, result=None): - """ - One can optionally pass in `result` as an array - which the answer will be read into. - """ + def get_joint_angles(self): + if not self.needs_new_joint_angles: + return self.data["joint_angle"] + points = self.get_points() - if result is not None: - assert(len(result) == len(points)) - else: - result = np.zeros(len(points)) + self.data["joint_angle"] = resize_array(self.data["joint_angle"], len(points)) if(len(points) == 0): - return result + return self.data["joint_angle"] # Unit tangent vectors + self.get_anchors_and_handles() a0, h, a1 = points[0::3], points[1::3], points[2::3] a0_to_h = normalize_along_axis(h - a0, 1) h_to_a1 = normalize_along_axis(a1 - h, 1) @@ -1053,14 +1051,13 @@ class VMobject(Mobject): vect_to_vert = np.zeros(points.shape) vect_from_vert = np.zeros(points.shape) - vect_to_vert[1::3] = a0_to_h - vect_from_vert[1::3] = h_to_a1 - vect_to_vert[0] = h_to_a1[-1] - vect_to_vert[3::3] = h_to_a1[:-1] - vect_from_vert[0::3] = a0_to_h - + vect_to_vert[1::3] = a0_to_h vect_to_vert[2::3] = h_to_a1 + vect_to_vert[3::3] = h_to_a1[:-1] + + vect_from_vert[0::3] = a0_to_h + vect_from_vert[1::3] = h_to_a1 vect_from_vert[2:-1:3] = a0_to_h[1:] vect_from_vert[-1] = a0_to_h[0] @@ -1068,22 +1065,23 @@ class VMobject(Mobject): dots = (vect_to_vert * vect_from_vert).sum(1) angle = np.arccos(arr_clip(dots, -1, 1)) sgn = np.sign(cross2d(vect_to_vert, vect_from_vert)) - result[:] = sgn * angle + self.data["joint_angle"][:, 0] = sgn * angle # To communicate to the shader that a given anchor point # sits at the end of a curve, we set its angle equal # to something outside the range [-pi, pi]. # An arbitrary constant is used - mis_matches = (a0[1:] != a1[:-1]).any(1) ends_mismatch = (a1[-1] != a0[0]).any() if ends_mismatch: - result[0] = DISJOINT_CONST - result[-1] = DISJOINT_CONST - result[3::3][mis_matches] = DISJOINT_CONST - result[2:-1:3][mis_matches] = DISJOINT_CONST + self.data["joint_angle"][0] = DISJOINT_CONST + self.data["joint_angle"][-1] = DISJOINT_CONST + + mis_matches = (a0[1:] != a1[:-1]).any(1) + self.data["joint_angle"][3::3][mis_matches] = DISJOINT_CONST + self.data["joint_angle"][2:-1:3][mis_matches] = DISJOINT_CONST self.needs_new_joint_angles = False - return result + return self.data["joint_angle"] def triggers_refreshed_triangulation(func: Callable): @wraps(func) @@ -1199,9 +1197,8 @@ class VMobject(Mobject): self.read_data_to_shader(self.stroke_data, "point", "points") self.read_data_to_shader(self.stroke_data, "color", "stroke_rgba") self.read_data_to_shader(self.stroke_data, "stroke_width", "stroke_width") - - if self.needs_new_joint_angles: - self.get_joint_angles(self.stroke_data["joint_angle"][:, 0]) + self.get_joint_angles() # Potentially refreshes + self.read_data_to_shader(self.stroke_data, "joint_angle", "joint_angle") return self.stroke_data