diff --git a/camera/three_d_camera.py b/camera/three_d_camera.py index 87eacac2..a4f4fe3d 100644 --- a/camera/three_d_camera.py +++ b/camera/three_d_camera.py @@ -6,6 +6,7 @@ from constants import * from camera.camera import Camera from mobject.types.point_cloud_mobject import Point +from mobject.types.vectorized_mobject import VMobject from mobject.three_d_utils import get_3d_vmob_start_corner from mobject.three_d_utils import get_3d_vmob_start_corner_unit_normal from mobject.three_d_utils import get_3d_vmob_end_corner @@ -94,22 +95,27 @@ class ThreeDCamera(Camera): vmobject, vmobject.get_fill_rgbas() ) - def display_multiple_vectorized_mobjects(self, vmobjects, pixel_array): + def get_mobjects_to_display(self, *args, **kwargs): + mobjects = Camera.get_mobjects_to_display( + self, *args, **kwargs + ) rot_matrix = self.get_rotation_matrix() - def z_key(vmob): + def z_key(mob): + if not isinstance(mob, VMobject): + return np.inf + if not mob.shade_in_3d: + return np.inf # Assign a number to a three dimensional mobjects # based on how close it is to the camera - if vmob.shade_in_3d: - return np.dot( - vmob.get_center(), - rot_matrix.T - )[2] - else: - return np.inf - Camera.display_multiple_vectorized_mobjects( - self, sorted(vmobjects, key=z_key), pixel_array - ) + points = mob.points + if len(points) == 0: + return 0 + return np.dot( + center_of_mass(points), + rot_matrix.T + )[2] + return sorted(mobjects, key=z_key) def get_phi(self): return self.phi_tracker.get_value()