Kill CONFIG in coordinate_system

This commit is contained in:
Grant Sanderson 2022-12-15 16:19:03 -08:00
parent 57875875c1
commit 5b5b3a7d20

View file

@ -33,51 +33,53 @@ from manimlib.utils.space_ops import normalize
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Callable, Iterable, Sequence, Type, TypeVar from typing import Callable, Iterable, Sequence, Type, TypeVar, Tuple
from manimlib.mobject.mobject import Mobject from manimlib.mobject.mobject import Mobject
from manimlib.constants import ManimColor from manimlib.constants import ManimColor, np_vector
RangeSpecifier = Tuple[float, float, float] | Tuple[float, float]
T = TypeVar("T", bound=Mobject) T = TypeVar("T", bound=Mobject)
EPSILON = 1e-8 EPSILON = 1e-8
DEFAULT_X_RANGE = (-8.0, 8.0, 1.0)
DEFAULT_Y_RANGE = (-4.0, 4.0, 1.0)
class CoordinateSystem(ABC): class CoordinateSystem(ABC):
""" """
Abstract class for Axes and NumberPlane Abstract class for Axes and NumberPlane
""" """
CONFIG = { dimension: int = 2
"dimension": 2,
"default_x_range": [-8.0, 8.0, 1.0],
"default_y_range": [-4.0, 4.0, 1.0],
"width": FRAME_WIDTH,
"height": FRAME_HEIGHT,
"num_sampled_graph_points_per_tick": 20,
}
def __init__(self, **kwargs): def __init__(
digest_config(self, kwargs) self,
self.x_range = np.array(self.default_x_range) x_range: RangeSpecifier = DEFAULT_X_RANGE,
self.y_range = np.array(self.default_y_range) y_range: RangeSpecifier = DEFAULT_Y_RANGE,
num_sampled_graph_points_per_tick: int = 20,
):
self.x_range = x_range
self.y_range = y_range
self.num_sampled_graph_points_per_tick = num_sampled_graph_points_per_tick
@abstractmethod @abstractmethod
def coords_to_point(self, *coords: float) -> np.ndarray: def coords_to_point(self, *coords: float) -> np_vector:
raise Exception("Not implemented") raise Exception("Not implemented")
@abstractmethod @abstractmethod
def point_to_coords(self, point: np.ndarray) -> tuple[float, ...]: def point_to_coords(self, point: np_vector) -> tuple[float, ...]:
raise Exception("Not implemented") raise Exception("Not implemented")
def c2p(self, *coords: float): def c2p(self, *coords: float):
"""Abbreviation for coords_to_point""" """Abbreviation for coords_to_point"""
return self.coords_to_point(*coords) return self.coords_to_point(*coords)
def p2c(self, point: np.ndarray): def p2c(self, point: np_vector):
"""Abbreviation for point_to_coords""" """Abbreviation for point_to_coords"""
return self.point_to_coords(point) return self.point_to_coords(point)
def get_origin(self) -> np.ndarray: def get_origin(self) -> np_vector:
return self.c2p(*[0] * self.dimension) return self.c2p(*[0] * self.dimension)
@abstractmethod @abstractmethod
@ -103,8 +105,8 @@ class CoordinateSystem(ABC):
def get_x_axis_label( def get_x_axis_label(
self, self,
label_tex: str, label_tex: str,
edge: np.ndarray = RIGHT, edge: np_vector = RIGHT,
direction: np.ndarray = DL, direction: np_vector = DL,
**kwargs **kwargs
) -> Tex: ) -> Tex:
return self.get_axis_label( return self.get_axis_label(
@ -115,8 +117,8 @@ class CoordinateSystem(ABC):
def get_y_axis_label( def get_y_axis_label(
self, self,
label_tex: str, label_tex: str,
edge: np.ndarray = UP, edge: np_vector = UP,
direction: np.ndarray = DR, direction: np_vector = DR,
**kwargs **kwargs
) -> Tex: ) -> Tex:
return self.get_axis_label( return self.get_axis_label(
@ -127,9 +129,9 @@ class CoordinateSystem(ABC):
def get_axis_label( def get_axis_label(
self, self,
label_tex: str, label_tex: str,
axis: np.ndarray, axis: np_vector,
edge: np.ndarray, edge: np_vector,
direction: np.ndarray, direction: np_vector,
buff: float = MED_SMALL_BUFF buff: float = MED_SMALL_BUFF
) -> Tex: ) -> Tex:
label = Tex(label_tex) label = Tex(label_tex)
@ -154,7 +156,7 @@ class CoordinateSystem(ABC):
def get_line_from_axis_to_point( def get_line_from_axis_to_point(
self, self,
index: int, index: int,
point: np.ndarray, point: np_vector,
line_func: Type[T] = DashedLine, line_func: Type[T] = DashedLine,
color: ManimColor = GREY_A, color: ManimColor = GREY_A,
stroke_width: float = 2 stroke_width: float = 2
@ -164,10 +166,10 @@ class CoordinateSystem(ABC):
line.set_stroke(color, stroke_width) line.set_stroke(color, stroke_width)
return line return line
def get_v_line(self, point: np.ndarray, **kwargs): def get_v_line(self, point: np_vector, **kwargs):
return self.get_line_from_axis_to_point(0, point, **kwargs) return self.get_line_from_axis_to_point(0, point, **kwargs)
def get_h_line(self, point: np.ndarray, **kwargs): def get_h_line(self, point: np_vector, **kwargs):
return self.get_line_from_axis_to_point(1, point, **kwargs) return self.get_line_from_axis_to_point(1, point, **kwargs)
# Useful for graphing # Useful for graphing
@ -197,7 +199,7 @@ class CoordinateSystem(ABC):
def get_parametric_curve( def get_parametric_curve(
self, self,
function: Callable[[float], np.ndarray], function: Callable[[float], np_vector],
**kwargs **kwargs
) -> ParametricCurve: ) -> ParametricCurve:
dim = self.dimension dim = self.dimension
@ -212,7 +214,7 @@ class CoordinateSystem(ABC):
self, self,
x: float, x: float,
graph: ParametricCurve graph: ParametricCurve
) -> np.ndarray | None: ) -> np_vector | None:
if hasattr(graph, "underlying_function"): if hasattr(graph, "underlying_function"):
return self.coords_to_point(x, graph.underlying_function(x)) return self.coords_to_point(x, graph.underlying_function(x))
else: else:
@ -229,7 +231,7 @@ class CoordinateSystem(ABC):
else: else:
return None return None
def i2gp(self, x: float, graph: ParametricCurve) -> np.ndarray | None: def i2gp(self, x: float, graph: ParametricCurve) -> np_vector | None:
""" """
Alias for input_to_graph_point Alias for input_to_graph_point
""" """
@ -264,7 +266,7 @@ class CoordinateSystem(ABC):
graph: ParametricCurve, graph: ParametricCurve,
label: str | Mobject = "f(x)", label: str | Mobject = "f(x)",
x: float | None = None, x: float | None = None,
direction: np.ndarray = RIGHT, direction: np_vector = RIGHT,
buff: float = MED_SMALL_BUFF, buff: float = MED_SMALL_BUFF,
color: ManimColor | None = None color: ManimColor | None = None
) -> Tex | Mobject: ) -> Tex | Mobject:
@ -301,8 +303,8 @@ class CoordinateSystem(ABC):
return self.get_h_line(self.i2gp(x, graph), **kwargs) return self.get_h_line(self.i2gp(x, graph), **kwargs)
def get_scatterplot(self, def get_scatterplot(self,
x_values: np.ndarray, x_values: np_vector,
y_values: np.ndarray, y_values: np_vector,
**dot_config): **dot_config):
return DotCloud(self.c2p(x_values, y_values), **dot_config) return DotCloud(self.c2p(x_values, y_values), **dot_config)
@ -398,38 +400,32 @@ class CoordinateSystem(ABC):
class Axes(VGroup, CoordinateSystem): class Axes(VGroup, CoordinateSystem):
CONFIG = {
"axis_config": {
"include_tip": False,
"numbers_to_exclude": [0],
},
"x_axis_config": {},
"y_axis_config": {
"line_to_number_direction": LEFT,
},
"height": FRAME_HEIGHT - 2,
"width": FRAME_WIDTH - 2,
}
def __init__( def __init__(
self, self,
x_range: Sequence[float] | None = None, x_range: RangeSpecifier = DEFAULT_X_RANGE,
y_range: Sequence[float] | None = None, y_range: RangeSpecifier = DEFAULT_Y_RANGE,
axis_config: dict = dict(
include_tip=False,
numbers_to_exclude=[0],
),
x_axis_config: dict = dict(),
y_axis_config: dict = dict(line_to_number_direction=LEFT),
height: float = FRAME_HEIGHT - 2,
width: float = FRAME_WIDTH - 2,
**kwargs **kwargs
): ):
CoordinateSystem.__init__(self, **kwargs) CoordinateSystem.__init__(self, x_range, y_range, **kwargs)
VGroup.__init__(self, **kwargs) VGroup.__init__(self, **kwargs)
if x_range is not None:
self.x_range[:len(x_range)] = x_range
if y_range is not None:
self.y_range[:len(y_range)] = y_range
self.x_axis = self.create_axis( self.x_axis = self.create_axis(
self.x_range, self.x_axis_config, self.width, self.x_range,
axis_config=merge_dicts_recursively(axis_config, x_axis_config),
length=width,
) )
self.y_axis = self.create_axis( self.y_axis = self.create_axis(
self.y_range, self.y_axis_config, self.height self.y_range,
axis_config=merge_dicts_recursively(axis_config, y_axis_config),
length=height
) )
self.y_axis.rotate(90 * DEGREES, about_point=ORIGIN) self.y_axis.rotate(90 * DEGREES, about_point=ORIGIN)
# Add as a separate group in case various other # Add as a separate group in case various other
@ -441,24 +437,22 @@ class Axes(VGroup, CoordinateSystem):
def create_axis( def create_axis(
self, self,
range_terms: Sequence[float], range_terms: RangeSpecifier,
axis_config: dict[str], axis_config: dict[str],
length: float length: float
) -> NumberLine: ) -> NumberLine:
new_config = merge_dicts_recursively(self.axis_config, axis_config) axis = NumberLine(range_terms, width=length, **axis_config)
new_config["width"] = length
axis = NumberLine(range_terms, **new_config)
axis.shift(-axis.n2p(0)) axis.shift(-axis.n2p(0))
return axis return axis
def coords_to_point(self, *coords: float) -> np.ndarray: def coords_to_point(self, *coords: float) -> np_vector:
origin = self.x_axis.number_to_point(0) origin = self.x_axis.number_to_point(0)
return origin + sum( return origin + sum(
axis.number_to_point(coord) - origin axis.number_to_point(coord) - origin
for axis, coord in zip(self.get_axes(), coords) for axis, coord in zip(self.get_axes(), coords)
) )
def point_to_coords(self, point: np.ndarray) -> tuple[float, ...]: def point_to_coords(self, point: np_vector) -> tuple[float, ...]:
return tuple([ return tuple([
axis.point_to_number(point) axis.point_to_number(point)
for axis in self.get_axes() for axis in self.get_axes()
@ -485,48 +479,39 @@ class Axes(VGroup, CoordinateSystem):
class ThreeDAxes(Axes): class ThreeDAxes(Axes):
CONFIG = { dimension: int = 3
"dimension": 3,
"x_range": np.array([-6.0, 6.0, 1.0]),
"y_range": np.array([-5.0, 5.0, 1.0]),
"z_range": np.array([-4.0, 4.0, 1.0]),
"z_axis_config": {},
"z_normal": DOWN,
"height": None,
"width": None,
"depth": None,
"num_axis_pieces": 20,
"gloss": 0.5,
}
def __init__( def __init__(
self, self,
x_range: Sequence[float] | None = None, x_range: RangeSpecifier = (-6.0, 6.0, 1.0),
y_range: Sequence[float] | None = None, y_range: RangeSpecifier = (-5.0, 5.0, 1.0),
z_range: Sequence[float] | None = None, z_range: RangeSpecifier = (-4.0, 4.0, 1.0),
z_axis_config: dict = dict(),
z_normal: np_vector = DOWN,
depth: float = 6.0,
num_axis_pieces: int = 20,
gloss: float = 0.5,
**kwargs **kwargs
): ):
Axes.__init__(self, x_range, y_range, **kwargs) Axes.__init__(self, x_range, y_range, **kwargs)
if z_range is not None:
self.z_range[:len(z_range)] = z_range
z_axis = self.create_axis( self.z_range = z_range
self.z_axis = self.create_axis(
self.z_range, self.z_range,
self.z_axis_config, axis_config=merge_dicts_recursively(kwargs.get("axes_config", {}), z_axis_config),
self.depth, length=depth,
) )
z_axis.rotate(-PI / 2, UP, about_point=ORIGIN) self.z_axis.rotate(-PI / 2, UP, about_point=ORIGIN)
z_axis.rotate( self.z_axis.rotate(
angle_of_vector(self.z_normal), OUT, angle_of_vector(z_normal), OUT,
about_point=ORIGIN about_point=ORIGIN
) )
z_axis.shift(self.x_axis.n2p(0)) self.z_axis.shift(self.x_axis.n2p(0))
self.axes.add(z_axis) self.axes.add(self.z_axis)
self.add(z_axis) self.add(self.z_axis)
self.z_axis = z_axis
for axis in self.axes: for axis in self.axes:
axis.insert_n_curves(self.num_axis_pieces - 1) axis.insert_n_curves(num_axis_pieces - 1)
def get_all_ranges(self) -> list[Sequence[float]]: def get_all_ranges(self) -> list[Sequence[float]]:
return [self.x_range, self.y_range, self.z_range] return [self.x_range, self.y_range, self.z_range]
@ -558,42 +543,50 @@ class ThreeDAxes(Axes):
class NumberPlane(Axes): class NumberPlane(Axes):
CONFIG = {
"axis_config": {
"stroke_color": WHITE,
"stroke_width": 2,
"include_ticks": False,
"include_tip": False,
"line_to_number_buff": SMALL_BUFF,
"line_to_number_direction": DL,
},
"y_axis_config": {
"line_to_number_direction": DL,
},
"background_line_style": {
"stroke_color": BLUE_D,
"stroke_width": 2,
"stroke_opacity": 1,
},
"height": None,
"width": None,
# Defaults to a faded version of line_config
"faded_line_style": None,
"faded_line_ratio": 4,
"make_smooth_after_applying_functions": True,
}
def __init__( def __init__(
self, self,
x_range: Sequence[float] | None = None, x_range: RangeSpecifier = (-8.0, 8.0, 1.0),
y_range: Sequence[float] | None = None, y_range: RangeSpecifier = (-4.0, 4.0, 1.0),
axis_config: dict = dict(
stroke_color=WHITE,
stroke_width=2,
include_ticks=False,
include_tip=False,
line_to_number_buff=SMALL_BUFF,
line_to_number_direction=DL,
),
y_axis_config: dict = dict(
line_to_number_direction=DL,
),
height: float = 8.0,
width: float = 16.0,
background_line_style: dict = dict(
stroke_color=BLUE_D,
stroke_width=2,
stroke_opacity=1,
),
# Defaults to a faded version of line_config
faded_line_style: dict = dict(),
faded_line_ratio: int = 4,
make_smooth_after_applying_functions: bool = True,
**kwargs **kwargs
): ):
super().__init__(x_range, y_range, **kwargs) super().__init__(
x_range, y_range,
height=height,
width=width,
axis_config=axis_config,
y_axis_config=y_axis_config,
**kwargs
)
self.background_line_style = background_line_style
self.faded_line_style = faded_line_style
self.faded_line_ratio = faded_line_ratio
self.make_smooth_after_applying_functions = make_smooth_after_applying_functions
self.init_background_lines() self.init_background_lines()
def init_background_lines(self) -> None: def init_background_lines(self) -> None:
if self.faded_line_style is None: if not self.faded_line_style:
style = dict(self.background_line_style) style = dict(self.background_line_style)
# For anything numerical, like stroke_width # For anything numerical, like stroke_width
# and stroke_opacity, chop it in half # and stroke_opacity, chop it in half
@ -666,23 +659,18 @@ class NumberPlane(Axes):
class ComplexPlane(NumberPlane): class ComplexPlane(NumberPlane):
CONFIG = { def number_to_point(self, number: complex | float) -> np_vector:
"color": BLUE,
"line_frequency": 1,
}
def number_to_point(self, number: complex | float) -> np.ndarray:
number = complex(number) number = complex(number)
return self.coords_to_point(number.real, number.imag) return self.coords_to_point(number.real, number.imag)
def n2p(self, number: complex | float) -> np.ndarray: def n2p(self, number: complex | float) -> np_vector:
return self.number_to_point(number) return self.number_to_point(number)
def point_to_number(self, point: np.ndarray) -> complex: def point_to_number(self, point: np_vector) -> complex:
x, y = self.point_to_coords(point) x, y = self.point_to_coords(point)
return complex(x, y) return complex(x, y)
def p2n(self, point: np.ndarray) -> complex: def p2n(self, point: np_vector) -> complex:
return self.point_to_number(point) return self.point_to_number(point)
def get_default_coordinate_values( def get_default_coordinate_values(