Reconfigured how to initialize axes and number planes. Hopefully much more sensible this way

This commit is contained in:
Grant Sanderson 2020-06-05 19:24:35 -07:00
parent fbe917d461
commit 1a5fb207ae
2 changed files with 86 additions and 119 deletions

View file

@ -21,10 +21,10 @@ class CoordinateSystem():
""" """
CONFIG = { CONFIG = {
"dimension": 2, "dimension": 2,
"x_min": -FRAME_X_RADIUS, "x_range": [-8, 8, 1],
"x_max": FRAME_X_RADIUS, "y_range": [-4, 4, 1],
"y_min": -FRAME_Y_RADIUS, "width": None,
"y_max": FRAME_Y_RADIUS, "height": None,
} }
def coords_to_point(self, *coords): def coords_to_point(self, *coords):
@ -131,40 +131,40 @@ class CoordinateSystem():
class Axes(VGroup, CoordinateSystem): class Axes(VGroup, CoordinateSystem):
CONFIG = { CONFIG = {
"axis_config": { "axis_config": {
"color": LIGHT_GREY,
"include_tip": True, "include_tip": True,
"exclude_zero_from_default_numbers": True,
}, },
"x_axis_config": {}, "x_axis_config": {},
"y_axis_config": { "y_axis_config": {
"label_direction": LEFT, "line_to_number_direction": LEFT,
}, },
"center_point": ORIGIN,
} }
def __init__(self, **kwargs): def __init__(self, x_range=None, y_range=None, **kwargs):
VGroup.__init__(self, **kwargs) VGroup.__init__(self, **kwargs)
self.x_axis = self.create_axis( self.x_axis = self.create_axis(
self.x_min, self.x_max, self.x_axis_config x_range or self.x_range,
self.x_axis_config,
self.width,
) )
self.y_axis = self.create_axis( self.y_axis = self.create_axis(
self.y_min, self.y_max, self.y_axis_config y_range or self.y_range,
self.y_axis_config,
self.height
) )
self.y_axis.rotate(90 * DEGREES, about_point=ORIGIN) self.y_axis.rotate(90 * DEGREES, about_point=ORIGIN)
# Add as a separate group incase various other # Add as a separate group in case various other
# mobjects are added to self, as for example in # mobjects are added to self, as for example in
# NumberPlane below # NumberPlane below
self.axes = VGroup(self.x_axis, self.y_axis) self.axes = VGroup(self.x_axis, self.y_axis)
self.add(*self.axes) self.add(*self.axes)
self.shift(self.center_point) self.center()
def create_axis(self, min_val, max_val, axis_config): def create_axis(self, range_terms, axis_config, length):
new_config = merge_dicts_recursively( new_config = merge_dicts_recursively(self.axis_config, axis_config)
self.axis_config, new_config["width"] = length
{"x_min": min_val, "x_max": max_val}, axis = NumberLine(range_terms, **new_config)
axis_config, axis.shift(-axis.n2p(0))
) return axis
return NumberLine(**new_config)
def coords_to_point(self, *coords): def coords_to_point(self, *coords):
origin = self.x_axis.number_to_point(0) origin = self.x_axis.number_to_point(0)
@ -173,64 +173,54 @@ class Axes(VGroup, CoordinateSystem):
result += (axis.number_to_point(coord) - origin) result += (axis.number_to_point(coord) - origin)
return result return result
def c2p(self, *coords):
return self.coords_to_point(*coords)
def point_to_coords(self, point): def point_to_coords(self, point):
return tuple([ return tuple([
axis.point_to_number(point) axis.point_to_number(point)
for axis in self.get_axes() for axis in self.get_axes()
]) ])
def p2c(self, point):
return self.point_to_coords(point)
def get_axes(self): def get_axes(self):
return self.axes return self.axes
def get_coordinate_labels(self, x_vals=None, y_vals=None): def add_coordinate_labels(self, x_values=None, y_values=None):
if x_vals is None: axes = self.get_axes()
x_vals = [] self.coordinate_labels = VGroup()
if y_vals is None: for axis, values in zip(axes, [x_values, y_values]):
y_vals = [] numbers = axis.add_numbers(values, excluding=[0])
x_mobs = self.get_x_axis().get_number_mobjects(*x_vals) self.coordinate_labels.add(numbers)
y_mobs = self.get_y_axis().get_number_mobjects(*y_vals)
self.coordinate_labels = VGroup(x_mobs, y_mobs)
return self.coordinate_labels return self.coordinate_labels
def add_coordinates(self, x_vals=None, y_vals=None):
self.add(self.get_coordinate_labels(x_vals, y_vals))
return self
class ThreeDAxes(Axes): class ThreeDAxes(Axes):
CONFIG = { CONFIG = {
"dimension": 3, "dimension": 3,
"x_min": -5.5, "x_range": (-6, 6, 1),
"x_max": 5.5, "y_range": (-5, 5, 1),
"y_min": -5.5, "z_range": (-4, 4, 1),
"y_max": 5.5,
"z_axis_config": {}, "z_axis_config": {},
"z_min": -3.5,
"z_max": 3.5,
"z_normal": DOWN, "z_normal": DOWN,
"depth": None,
"num_axis_pieces": 20, "num_axis_pieces": 20,
"gloss": 0.5, "gloss": 0.5,
} }
def __init__(self, **kwargs): def __init__(self, x_range=None, y_range=None, z_range=None, **kwargs):
Axes.__init__(self, **kwargs) Axes.__init__(self, x_range, y_range, **kwargs)
z_axis = self.z_axis = self.create_axis(
self.z_min, self.z_max, self.z_axis_config z_axis = self.create_axis(
z_range or self.z_range,
self.z_axis_config,
self.depth,
) )
z_axis.rotate(-np.pi / 2, UP, about_point=ORIGIN) z_axis.rotate(-PI / 2, UP, about_point=ORIGIN)
z_axis.rotate( z_axis.rotate(
angle_of_vector(self.z_normal), OUT, angle_of_vector(self.z_normal), OUT,
about_point=ORIGIN about_point=ORIGIN
) )
z_axis.shift(self.x_axis.n2p(0))
self.axes.add(z_axis) self.axes.add(z_axis)
self.add(z_axis) self.add(z_axis)
self.z_axis = z_axis
for axis in self.axes: for axis in self.axes:
axis.insert_n_curves(self.num_axis_pieces - 1) axis.insert_n_curves(self.num_axis_pieces - 1)
@ -244,11 +234,13 @@ class NumberPlane(Axes):
"include_ticks": False, "include_ticks": False,
"include_tip": False, "include_tip": False,
"line_to_number_buff": SMALL_BUFF, "line_to_number_buff": SMALL_BUFF,
"label_direction": DR, "line_to_number_direction": DL,
"number_scale_val": 0.5, "decimal_number_config": {
"height": 0.2,
}
}, },
"y_axis_config": { "y_axis_config": {
"label_direction": DR, "line_to_number_direction": DL,
}, },
"background_line_style": { "background_line_style": {
"stroke_color": BLUE_D, "stroke_color": BLUE_D,
@ -257,8 +249,6 @@ class NumberPlane(Axes):
}, },
# Defaults to a faded version of line_config # Defaults to a faded version of line_config
"faded_line_style": None, "faded_line_style": None,
"x_line_frequency": 1,
"y_line_frequency": 1,
"faded_line_ratio": 1, "faded_line_ratio": 1,
"make_smooth_after_applying_functions": True, "make_smooth_after_applying_functions": True,
} }
@ -278,12 +268,8 @@ class NumberPlane(Axes):
self.faded_line_style = style self.faded_line_style = style
self.background_lines, self.faded_lines = self.get_lines() self.background_lines, self.faded_lines = self.get_lines()
self.background_lines.set_style( self.background_lines.set_style(**self.background_line_style)
**self.background_line_style, self.faded_lines.set_style(**self.faded_line_style)
)
self.faded_lines.set_style(
**self.faded_line_style,
)
self.add_to_back( self.add_to_back(
self.faded_lines, self.faded_lines,
self.background_lines, self.background_lines,
@ -292,45 +278,32 @@ class NumberPlane(Axes):
def get_lines(self): def get_lines(self):
x_axis = self.get_x_axis() x_axis = self.get_x_axis()
y_axis = self.get_y_axis() y_axis = self.get_y_axis()
x_freq = self.x_line_frequency
y_freq = self.y_line_frequency
x_lines1, x_lines2 = self.get_lines_parallel_to_axis( x_lines1, x_lines2 = self.get_lines_parallel_to_axis(x_axis, y_axis)
x_axis, y_axis, x_freq, y_lines1, y_lines2 = self.get_lines_parallel_to_axis(y_axis, x_axis)
self.faded_line_ratio,
)
y_lines1, y_lines2 = self.get_lines_parallel_to_axis(
y_axis, x_axis, y_freq,
self.faded_line_ratio,
)
lines1 = VGroup(*x_lines1, *y_lines1) lines1 = VGroup(*x_lines1, *y_lines1)
lines2 = VGroup(*x_lines2, *y_lines2) lines2 = VGroup(*x_lines2, *y_lines2)
return lines1, lines2 return lines1, lines2
def get_lines_parallel_to_axis(self, axis1, axis2, freq, ratio): def get_lines_parallel_to_axis(self, axis1, axis2):
freq = axis1.x_step
ratio = self.faded_line_ratio
line = Line(axis1.get_start(), axis1.get_end()) line = Line(axis1.get_start(), axis1.get_end())
dense_freq = (1 + ratio) dense_freq = (1 + ratio)
step = (1 / dense_freq) * freq step = (1 / dense_freq) * freq
lines1 = VGroup() lines1 = VGroup()
lines2 = VGroup() lines2 = VGroup()
ranges = ( inputs = np.arange(axis2.x_min, axis2.x_max + step, step)
np.arange(0, axis2.x_max, step), for i, x in enumerate(inputs):
np.arange(0, axis2.x_min, -step),
)
for inputs in ranges:
for k, x in enumerate(inputs):
new_line = line.copy() new_line = line.copy()
new_line.move_to(axis2.number_to_point(x)) new_line.shift(axis2.n2p(x) - axis2.n2p(0))
if k % (1 + ratio) == 0: if i % (1 + ratio) == 0:
lines1.add(new_line) lines1.add(new_line)
else: else:
lines2.add(new_line) lines2.add(new_line)
return lines1, lines2 return lines1, lines2
def get_center_point(self):
return self.coords_to_point(0, 0)
def get_x_unit_size(self): def get_x_unit_size(self):
return self.get_x_axis().get_unit_size() return self.get_x_axis().get_unit_size()
@ -342,19 +315,13 @@ class NumberPlane(Axes):
def get_vector(self, coords, **kwargs): def get_vector(self, coords, **kwargs):
kwargs["buff"] = 0 kwargs["buff"] = 0
return Arrow( return Arrow(self.c2p(0, 0), self.c2p(*coords), **kwargs)
self.coords_to_point(0, 0),
self.coords_to_point(*coords),
**kwargs
)
def prepare_for_nonlinear_transform(self, num_inserted_curves=50): def prepare_for_nonlinear_transform(self, num_inserted_curves=50):
for mob in self.family_members_with_points(): for mob in self.family_members_with_points():
num_curves = mob.get_num_curves() num_curves = mob.get_num_curves()
if num_inserted_curves > num_curves: if num_inserted_curves > num_curves:
mob.insert_n_curves( mob.insert_n_curves(num_inserted_curves - num_curves)
num_inserted_curves - num_curves
)
return self return self
@ -379,15 +346,13 @@ class ComplexPlane(NumberPlane):
return self.point_to_number(point) return self.point_to_number(point)
def get_default_coordinate_values(self): def get_default_coordinate_values(self):
x_numbers = self.get_x_axis().default_numbers_to_display() x_numbers = self.get_x_axis().get_tick_range()
y_numbers = self.get_y_axis().default_numbers_to_display() y_numbers = self.get_y_axis().get_tick_range()
y_numbers = [ y_numbers = [complex(0, y) for y in y_numbers if y != 0]
complex(0, y) for y in y_numbers if y != 0
]
return [*x_numbers, *y_numbers] return [*x_numbers, *y_numbers]
def get_coordinate_labels(self, *numbers, **kwargs): def add_coordinate_labels(self, numbers=None, **kwargs):
if len(numbers) == 0: if numbers is None:
numbers = self.get_default_coordinate_values() numbers = self.get_default_coordinate_values()
self.coordinate_labels = VGroup() self.coordinate_labels = VGroup()
@ -405,8 +370,5 @@ class ComplexPlane(NumberPlane):
value = z.real value = z.real
number_mob = axis.get_number_mobject(value, **kwargs) number_mob = axis.get_number_mobject(value, **kwargs)
self.coordinate_labels.add(number_mob) self.coordinate_labels.add(number_mob)
return self.coordinate_labels self.add(self.coordinate_labels)
def add_coordinates(self, *numbers):
self.add(self.get_coordinate_labels(*numbers))
return self return self

View file

@ -5,6 +5,7 @@ from manimlib.mobject.types.vectorized_mobject import VGroup
from manimlib.utils.bezier import interpolate from manimlib.utils.bezier import interpolate
from manimlib.utils.config_ops import digest_config from manimlib.utils.config_ops import digest_config
from manimlib.utils.config_ops import merge_dicts_recursively from manimlib.utils.config_ops import merge_dicts_recursively
from manimlib.utils.iterables import list_difference_update
from manimlib.utils.simple_functions import fdiv from manimlib.utils.simple_functions import fdiv
from manimlib.utils.space_ops import normalize from manimlib.utils.space_ops import normalize
@ -12,11 +13,12 @@ from manimlib.utils.space_ops import normalize
class NumberLine(Line): class NumberLine(Line):
CONFIG = { CONFIG = {
"color": LIGHT_GREY, "color": LIGHT_GREY,
"stroke_width": 3, "stroke_width": 2,
# List of 2 or 3 elements, x_min, x_max, step_size # List of 2 or 3 elements, x_min, x_max, step_size
"x_range": [-8, 8, 1], "x_range": [-8, 8, 1],
# How big is one one unit of this number line in terms of absolute spacial distance # How big is one one unit of this number line in terms of absolute spacial distance
"unit_size": 1, "unit_size": 1,
"width": None,
"include_ticks": True, "include_ticks": True,
"tick_size": 0.1, "tick_size": 0.1,
"longer_tick_multiple": 1.5, "longer_tick_multiple": 1.5,
@ -37,12 +39,12 @@ class NumberLine(Line):
def __init__(self, x_range=None, **kwargs): def __init__(self, x_range=None, **kwargs):
digest_config(self, kwargs) digest_config(self, kwargs)
if x_range is not None: if x_range is None:
self.x_range = x_range x_range = self.x_range
if len(self.x_range) == 2: if len(x_range) == 2:
self.x_range.append(1) x_range = [*x_range, 1]
x_min, x_max, x_step = self.x_range x_min, x_max, x_step = x_range
# A lot of old scenes pass in x_min or x_max explicitly, # A lot of old scenes pass in x_min or x_max explicitly,
# so this is just here to keep those workin # so this is just here to keep those workin
self.x_min = kwargs.get("x_min", x_min) self.x_min = kwargs.get("x_min", x_min)
@ -50,6 +52,9 @@ class NumberLine(Line):
self.x_step = kwargs.get("x_step", x_step) self.x_step = kwargs.get("x_step", x_step)
super().__init__(self.x_min * RIGHT, self.x_max * RIGHT, **kwargs) super().__init__(self.x_min * RIGHT, self.x_max * RIGHT, **kwargs)
if self.width:
self.set_width(self.width)
else:
self.scale(self.unit_size) self.scale(self.unit_size)
self.center() self.center()
@ -74,8 +79,6 @@ class NumberLine(Line):
if x in self.numbers_with_elongated_ticks: if x in self.numbers_with_elongated_ticks:
size *= self.longer_tick_multiple size *= self.longer_tick_multiple
ticks.add(self.get_tick(x, size)) ticks.add(self.get_tick(x, size))
if self.include_tip:
ticks.remove(ticks[-1])
self.add(ticks) self.add(ticks)
self.ticks = ticks self.ticks = ticks
@ -135,20 +138,22 @@ class NumberLine(Line):
direction=direction, direction=direction,
buff=buff buff=buff
) )
if x < 0: if x < 0 and self.line_to_number_direction[0] == 0:
# Align without the minus sign # Align without the minus sign
num_mob.shift(num_mob[0].get_width() * LEFT / 2) num_mob.shift(num_mob[0].get_width() * LEFT / 2)
return num_mob return num_mob
def add_numbers(self, *x_values, **kwargs): def add_numbers(self, x_values=None, excluding=None, **kwargs):
if len(x_values) == 0: if x_values is None:
x_values = self.get_tick_range() x_values = self.get_tick_range()
if excluding is not None:
x_values = list_difference_update(x_values, excluding)
self.numbers = VGroup() self.numbers = VGroup()
for x in x_values: for x in x_values:
self.numbers.add(self.get_number_mobject(x, **kwargs)) self.numbers.add(self.get_number_mobject(x, **kwargs))
self.add(self.numbers) self.add(self.numbers)
return self return self.numbers
class UnitInterval(NumberLine): class UnitInterval(NumberLine):