Kill CONFIG in coordinate_system

This commit is contained in:
Grant Sanderson 2022-12-15 16:19:03 -08:00
parent 57875875c1
commit 5b5b3a7d20

View file

@ -33,51 +33,53 @@ from manimlib.utils.space_ops import normalize
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Callable, Iterable, Sequence, Type, TypeVar
from typing import Callable, Iterable, Sequence, Type, TypeVar, Tuple
from manimlib.mobject.mobject import Mobject
from manimlib.constants import ManimColor
from manimlib.constants import ManimColor, np_vector
RangeSpecifier = Tuple[float, float, float] | Tuple[float, float]
T = TypeVar("T", bound=Mobject)
EPSILON = 1e-8
DEFAULT_X_RANGE = (-8.0, 8.0, 1.0)
DEFAULT_Y_RANGE = (-4.0, 4.0, 1.0)
class CoordinateSystem(ABC):
"""
Abstract class for Axes and NumberPlane
"""
CONFIG = {
"dimension": 2,
"default_x_range": [-8.0, 8.0, 1.0],
"default_y_range": [-4.0, 4.0, 1.0],
"width": FRAME_WIDTH,
"height": FRAME_HEIGHT,
"num_sampled_graph_points_per_tick": 20,
}
dimension: int = 2
def __init__(self, **kwargs):
digest_config(self, kwargs)
self.x_range = np.array(self.default_x_range)
self.y_range = np.array(self.default_y_range)
def __init__(
self,
x_range: RangeSpecifier = DEFAULT_X_RANGE,
y_range: RangeSpecifier = DEFAULT_Y_RANGE,
num_sampled_graph_points_per_tick: int = 20,
):
self.x_range = x_range
self.y_range = y_range
self.num_sampled_graph_points_per_tick = num_sampled_graph_points_per_tick
@abstractmethod
def coords_to_point(self, *coords: float) -> np.ndarray:
def coords_to_point(self, *coords: float) -> np_vector:
raise Exception("Not implemented")
@abstractmethod
def point_to_coords(self, point: np.ndarray) -> tuple[float, ...]:
def point_to_coords(self, point: np_vector) -> tuple[float, ...]:
raise Exception("Not implemented")
def c2p(self, *coords: float):
"""Abbreviation for coords_to_point"""
return self.coords_to_point(*coords)
def p2c(self, point: np.ndarray):
def p2c(self, point: np_vector):
"""Abbreviation for point_to_coords"""
return self.point_to_coords(point)
def get_origin(self) -> np.ndarray:
def get_origin(self) -> np_vector:
return self.c2p(*[0] * self.dimension)
@abstractmethod
@ -103,8 +105,8 @@ class CoordinateSystem(ABC):
def get_x_axis_label(
self,
label_tex: str,
edge: np.ndarray = RIGHT,
direction: np.ndarray = DL,
edge: np_vector = RIGHT,
direction: np_vector = DL,
**kwargs
) -> Tex:
return self.get_axis_label(
@ -115,8 +117,8 @@ class CoordinateSystem(ABC):
def get_y_axis_label(
self,
label_tex: str,
edge: np.ndarray = UP,
direction: np.ndarray = DR,
edge: np_vector = UP,
direction: np_vector = DR,
**kwargs
) -> Tex:
return self.get_axis_label(
@ -127,9 +129,9 @@ class CoordinateSystem(ABC):
def get_axis_label(
self,
label_tex: str,
axis: np.ndarray,
edge: np.ndarray,
direction: np.ndarray,
axis: np_vector,
edge: np_vector,
direction: np_vector,
buff: float = MED_SMALL_BUFF
) -> Tex:
label = Tex(label_tex)
@ -154,7 +156,7 @@ class CoordinateSystem(ABC):
def get_line_from_axis_to_point(
self,
index: int,
point: np.ndarray,
point: np_vector,
line_func: Type[T] = DashedLine,
color: ManimColor = GREY_A,
stroke_width: float = 2
@ -164,10 +166,10 @@ class CoordinateSystem(ABC):
line.set_stroke(color, stroke_width)
return line
def get_v_line(self, point: np.ndarray, **kwargs):
def get_v_line(self, point: np_vector, **kwargs):
return self.get_line_from_axis_to_point(0, point, **kwargs)
def get_h_line(self, point: np.ndarray, **kwargs):
def get_h_line(self, point: np_vector, **kwargs):
return self.get_line_from_axis_to_point(1, point, **kwargs)
# Useful for graphing
@ -197,7 +199,7 @@ class CoordinateSystem(ABC):
def get_parametric_curve(
self,
function: Callable[[float], np.ndarray],
function: Callable[[float], np_vector],
**kwargs
) -> ParametricCurve:
dim = self.dimension
@ -212,7 +214,7 @@ class CoordinateSystem(ABC):
self,
x: float,
graph: ParametricCurve
) -> np.ndarray | None:
) -> np_vector | None:
if hasattr(graph, "underlying_function"):
return self.coords_to_point(x, graph.underlying_function(x))
else:
@ -229,7 +231,7 @@ class CoordinateSystem(ABC):
else:
return None
def i2gp(self, x: float, graph: ParametricCurve) -> np.ndarray | None:
def i2gp(self, x: float, graph: ParametricCurve) -> np_vector | None:
"""
Alias for input_to_graph_point
"""
@ -264,7 +266,7 @@ class CoordinateSystem(ABC):
graph: ParametricCurve,
label: str | Mobject = "f(x)",
x: float | None = None,
direction: np.ndarray = RIGHT,
direction: np_vector = RIGHT,
buff: float = MED_SMALL_BUFF,
color: ManimColor | None = None
) -> Tex | Mobject:
@ -301,8 +303,8 @@ class CoordinateSystem(ABC):
return self.get_h_line(self.i2gp(x, graph), **kwargs)
def get_scatterplot(self,
x_values: np.ndarray,
y_values: np.ndarray,
x_values: np_vector,
y_values: np_vector,
**dot_config):
return DotCloud(self.c2p(x_values, y_values), **dot_config)
@ -398,38 +400,32 @@ class CoordinateSystem(ABC):
class Axes(VGroup, CoordinateSystem):
CONFIG = {
"axis_config": {
"include_tip": False,
"numbers_to_exclude": [0],
},
"x_axis_config": {},
"y_axis_config": {
"line_to_number_direction": LEFT,
},
"height": FRAME_HEIGHT - 2,
"width": FRAME_WIDTH - 2,
}
def __init__(
self,
x_range: Sequence[float] | None = None,
y_range: Sequence[float] | None = None,
x_range: RangeSpecifier = DEFAULT_X_RANGE,
y_range: RangeSpecifier = DEFAULT_Y_RANGE,
axis_config: dict = dict(
include_tip=False,
numbers_to_exclude=[0],
),
x_axis_config: dict = dict(),
y_axis_config: dict = dict(line_to_number_direction=LEFT),
height: float = FRAME_HEIGHT - 2,
width: float = FRAME_WIDTH - 2,
**kwargs
):
CoordinateSystem.__init__(self, **kwargs)
CoordinateSystem.__init__(self, x_range, y_range, **kwargs)
VGroup.__init__(self, **kwargs)
if x_range is not None:
self.x_range[:len(x_range)] = x_range
if y_range is not None:
self.y_range[:len(y_range)] = y_range
self.x_axis = self.create_axis(
self.x_range, self.x_axis_config, self.width,
self.x_range,
axis_config=merge_dicts_recursively(axis_config, x_axis_config),
length=width,
)
self.y_axis = self.create_axis(
self.y_range, self.y_axis_config, self.height
self.y_range,
axis_config=merge_dicts_recursively(axis_config, y_axis_config),
length=height
)
self.y_axis.rotate(90 * DEGREES, about_point=ORIGIN)
# Add as a separate group in case various other
@ -441,24 +437,22 @@ class Axes(VGroup, CoordinateSystem):
def create_axis(
self,
range_terms: Sequence[float],
range_terms: RangeSpecifier,
axis_config: dict[str],
length: float
) -> NumberLine:
new_config = merge_dicts_recursively(self.axis_config, axis_config)
new_config["width"] = length
axis = NumberLine(range_terms, **new_config)
axis = NumberLine(range_terms, width=length, **axis_config)
axis.shift(-axis.n2p(0))
return axis
def coords_to_point(self, *coords: float) -> np.ndarray:
def coords_to_point(self, *coords: float) -> np_vector:
origin = self.x_axis.number_to_point(0)
return origin + sum(
axis.number_to_point(coord) - origin
for axis, coord in zip(self.get_axes(), coords)
)
def point_to_coords(self, point: np.ndarray) -> tuple[float, ...]:
def point_to_coords(self, point: np_vector) -> tuple[float, ...]:
return tuple([
axis.point_to_number(point)
for axis in self.get_axes()
@ -485,48 +479,39 @@ class Axes(VGroup, CoordinateSystem):
class ThreeDAxes(Axes):
CONFIG = {
"dimension": 3,
"x_range": np.array([-6.0, 6.0, 1.0]),
"y_range": np.array([-5.0, 5.0, 1.0]),
"z_range": np.array([-4.0, 4.0, 1.0]),
"z_axis_config": {},
"z_normal": DOWN,
"height": None,
"width": None,
"depth": None,
"num_axis_pieces": 20,
"gloss": 0.5,
}
dimension: int = 3
def __init__(
self,
x_range: Sequence[float] | None = None,
y_range: Sequence[float] | None = None,
z_range: Sequence[float] | None = None,
x_range: RangeSpecifier = (-6.0, 6.0, 1.0),
y_range: RangeSpecifier = (-5.0, 5.0, 1.0),
z_range: RangeSpecifier = (-4.0, 4.0, 1.0),
z_axis_config: dict = dict(),
z_normal: np_vector = DOWN,
depth: float = 6.0,
num_axis_pieces: int = 20,
gloss: float = 0.5,
**kwargs
):
Axes.__init__(self, x_range, y_range, **kwargs)
if z_range is not None:
self.z_range[:len(z_range)] = z_range
z_axis = self.create_axis(
self.z_range = z_range
self.z_axis = self.create_axis(
self.z_range,
self.z_axis_config,
self.depth,
axis_config=merge_dicts_recursively(kwargs.get("axes_config", {}), z_axis_config),
length=depth,
)
z_axis.rotate(-PI / 2, UP, about_point=ORIGIN)
z_axis.rotate(
angle_of_vector(self.z_normal), OUT,
self.z_axis.rotate(-PI / 2, UP, about_point=ORIGIN)
self.z_axis.rotate(
angle_of_vector(z_normal), OUT,
about_point=ORIGIN
)
z_axis.shift(self.x_axis.n2p(0))
self.axes.add(z_axis)
self.add(z_axis)
self.z_axis = z_axis
self.z_axis.shift(self.x_axis.n2p(0))
self.axes.add(self.z_axis)
self.add(self.z_axis)
for axis in self.axes:
axis.insert_n_curves(self.num_axis_pieces - 1)
axis.insert_n_curves(num_axis_pieces - 1)
def get_all_ranges(self) -> list[Sequence[float]]:
return [self.x_range, self.y_range, self.z_range]
@ -558,42 +543,50 @@ class ThreeDAxes(Axes):
class NumberPlane(Axes):
CONFIG = {
"axis_config": {
"stroke_color": WHITE,
"stroke_width": 2,
"include_ticks": False,
"include_tip": False,
"line_to_number_buff": SMALL_BUFF,
"line_to_number_direction": DL,
},
"y_axis_config": {
"line_to_number_direction": DL,
},
"background_line_style": {
"stroke_color": BLUE_D,
"stroke_width": 2,
"stroke_opacity": 1,
},
"height": None,
"width": None,
# Defaults to a faded version of line_config
"faded_line_style": None,
"faded_line_ratio": 4,
"make_smooth_after_applying_functions": True,
}
def __init__(
self,
x_range: Sequence[float] | None = None,
y_range: Sequence[float] | None = None,
x_range: RangeSpecifier = (-8.0, 8.0, 1.0),
y_range: RangeSpecifier = (-4.0, 4.0, 1.0),
axis_config: dict = dict(
stroke_color=WHITE,
stroke_width=2,
include_ticks=False,
include_tip=False,
line_to_number_buff=SMALL_BUFF,
line_to_number_direction=DL,
),
y_axis_config: dict = dict(
line_to_number_direction=DL,
),
height: float = 8.0,
width: float = 16.0,
background_line_style: dict = dict(
stroke_color=BLUE_D,
stroke_width=2,
stroke_opacity=1,
),
# Defaults to a faded version of line_config
faded_line_style: dict = dict(),
faded_line_ratio: int = 4,
make_smooth_after_applying_functions: bool = True,
**kwargs
):
super().__init__(x_range, y_range, **kwargs)
super().__init__(
x_range, y_range,
height=height,
width=width,
axis_config=axis_config,
y_axis_config=y_axis_config,
**kwargs
)
self.background_line_style = background_line_style
self.faded_line_style = faded_line_style
self.faded_line_ratio = faded_line_ratio
self.make_smooth_after_applying_functions = make_smooth_after_applying_functions
self.init_background_lines()
def init_background_lines(self) -> None:
if self.faded_line_style is None:
if not self.faded_line_style:
style = dict(self.background_line_style)
# For anything numerical, like stroke_width
# and stroke_opacity, chop it in half
@ -666,23 +659,18 @@ class NumberPlane(Axes):
class ComplexPlane(NumberPlane):
CONFIG = {
"color": BLUE,
"line_frequency": 1,
}
def number_to_point(self, number: complex | float) -> np.ndarray:
def number_to_point(self, number: complex | float) -> np_vector:
number = complex(number)
return self.coords_to_point(number.real, number.imag)
def n2p(self, number: complex | float) -> np.ndarray:
def n2p(self, number: complex | float) -> np_vector:
return self.number_to_point(number)
def point_to_number(self, point: np.ndarray) -> complex:
def point_to_number(self, point: np_vector) -> complex:
x, y = self.point_to_coords(point)
return complex(x, y)
def p2n(self, point: np.ndarray) -> complex:
def p2n(self, point: np_vector) -> complex:
return self.point_to_number(point)
def get_default_coordinate_values(