diff --git a/mobject/coordinate_systems.py b/mobject/coordinate_systems.py index 8bc823ba..8115fc83 100644 --- a/mobject/coordinate_systems.py +++ b/mobject/coordinate_systems.py @@ -11,6 +11,7 @@ from mobject.number_line import NumberLine from mobject.svg.tex_mobject import TexMobject from mobject.types.vectorized_mobject import VGroup from mobject.types.vectorized_mobject import VMobject +from mobject.three_dimensions import ThreeDVMobject from utils.config_ops import digest_config from utils.space_ops import angle_of_vector @@ -29,14 +30,10 @@ class Axes(VGroup): "y_axis_config": { "label_direction": LEFT, }, - "z_axis_config": {}, "x_min": -FRAME_X_RADIUS, "x_max": FRAME_X_RADIUS, "y_min": -FRAME_Y_RADIUS, "y_max": FRAME_Y_RADIUS, - "z_min": -3.5, - "z_max": 3.5, - "z_normal": DOWN, "default_num_graph_points": 100, } @@ -46,15 +43,6 @@ class Axes(VGroup): self.y_axis = self.get_axis(self.y_min, self.y_max, self.y_axis_config) self.y_axis.rotate(np.pi / 2, about_point=ORIGIN) self.add(self.x_axis, self.y_axis) - if self.three_d: - self.z_axis = self.get_axis( - self.z_min, self.z_max, self.z_axis_config) - self.z_axis.rotate(-np.pi / 2, UP, about_point=ORIGIN) - self.z_axis.rotate( - angle_of_vector(self.z_normal), OUT, - about_point=ORIGIN - ) - self.add(self.z_axis) def get_axis(self, min_val, max_val, extra_config): config = dict(self.number_line_config) @@ -68,10 +56,10 @@ class Axes(VGroup): return x_axis_projection + y_axis_projection - origin def point_to_coords(self, point): - return ( - self.x_axis.point_to_number(point), - self.y_axis.point_to_number(point), - ) + return tuple([ + axis.point_to_number(point) + for axis in self + ]) def get_graph( self, function, num_graph_points=None, @@ -123,11 +111,56 @@ class ThreeDAxes(Axes): CONFIG = { "x_min": -5.5, "x_max": 5.5, - "y_min": -4.5, - "y_max": 4.5, - "three_d": True, + "y_min": -5.5, + "y_max": 5.5, + "z_axis_config": {}, + "z_min": -3.5, + "z_max": 3.5, + "z_normal": DOWN, + "num_axis_pieces": 20, + "light_source": 9 * DOWN + 7 * LEFT + 10 * OUT, } + def __init__(self, **kwargs): + Axes.__init__(self, **kwargs) + z_axis = self.z_axis = self.get_axis( + self.z_min, self.z_max, self.z_axis_config + ) + z_axis.rotate(-np.pi / 2, UP, about_point=ORIGIN) + z_axis.rotate( + angle_of_vector(self.z_normal), OUT, + about_point=ORIGIN + ) + self.add(z_axis) + + self.add_3d_pieces() + self.set_axis_shading() + + def add_3d_pieces(self): + for attr in "x_axis", "y_axis", "z_axis": + axis = getattr(self, attr) + axis.add(VGroup( + *axis.main_line.get_pieces(self.num_axis_pieces) + )) + axis.main_line.set_stroke(width=0, family=False) + axis_3d = ThreeDVMobject(axis) + self.remove(axis) + self.add(axis_3d) + setattr(self, attr, axis_3d) + + def set_axis_shading(self): + def make_func(axis): + vect = self.light_source + return lambda: ( + axis.get_edge_center(-vect), + axis.get_edge_center(vect), + ) + for axis in self: + for submob in axis.family_members_with_points(): + submob.get_gradient_start_and_end_points = make_func(axis) + submob.get_unit_normal = lambda a: np.ones(3) + submob.set_sheen(0.2) + class NumberPlane(VMobject): CONFIG = {