Distinguish Vect3 from Vect3Array types

This commit is contained in:
Grant Sanderson 2022-12-17 13:16:48 -08:00
parent 8db20cc460
commit 97f28b34f3
13 changed files with 83 additions and 77 deletions

View file

@ -26,12 +26,9 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING:
from manimlib.shader_wrapper import ShaderWrapper
from manimlib.typing import ManimColor
from manimlib.typing import ManimColor, Vect3
from typing import Sequence
Vect3 = np.ndarray[int, np.dtype[np.float64]]
class CameraFrame(Mobject):
def __init__(
self,

View file

@ -34,7 +34,7 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Callable, Iterable, Sequence, Type, TypeVar
from manimlib.mobject.mobject import Mobject
from manimlib.typing import ManimColor, Vect3, RangeSpecifier
from manimlib.typing import ManimColor, Vect3, Vect3Array, VectN, RangeSpecifier
T = TypeVar("T", bound=Mobject)
@ -61,18 +61,18 @@ class CoordinateSystem(ABC):
self.num_sampled_graph_points_per_tick = num_sampled_graph_points_per_tick
@abstractmethod
def coords_to_point(self, *coords: float) -> Vect3:
def coords_to_point(self, *coords: float | VectN) -> Vect3 | Vect3Array:
raise Exception("Not implemented")
@abstractmethod
def point_to_coords(self, point: Vect3) -> tuple[float, ...]:
def point_to_coords(self, point: Vect3 | Vect3Array) -> tuple[float | VectN, ...]:
raise Exception("Not implemented")
def c2p(self, *coords: float):
def c2p(self, *coords: float) -> Vect3 | Vect3Array:
"""Abbreviation for coords_to_point"""
return self.coords_to_point(*coords)
def p2c(self, point: Vect3):
def p2c(self, point: Vect3) -> tuple[float | VectN, ...]:
"""Abbreviation for point_to_coords"""
return self.point_to_coords(point)
@ -302,8 +302,8 @@ class CoordinateSystem(ABC):
return self.get_h_line(self.i2gp(x, graph), **kwargs)
def get_scatterplot(self,
x_values: Vect3,
y_values: Vect3,
x_values: Vect3Array,
y_values: Vect3Array,
**dot_config):
return DotCloud(self.c2p(x_values, y_values), **dot_config)
@ -449,14 +449,14 @@ class Axes(VGroup, CoordinateSystem):
axis.shift(-axis.n2p(0))
return axis
def coords_to_point(self, *coords: float) -> Vect3:
def coords_to_point(self, *coords: float | VectN) -> Vect3 | Vect3Array:
origin = self.x_axis.number_to_point(0)
return origin + sum(
axis.number_to_point(coord) - origin
for axis, coord in zip(self.get_axes(), coords)
)
def point_to_coords(self, point: Vect3) -> tuple[float, ...]:
def point_to_coords(self, point: Vect3 | Vect3Array) -> tuple[float | VectN, ...]:
return tuple([
axis.point_to_number(point)
for axis in self.get_axes()

View file

@ -30,7 +30,7 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Iterable
from manimlib.typing import ManimColor, Vect3
from manimlib.typing import ManimColor, Vect3, Vect3Array
DEFAULT_DOT_RADIUS = 0.08
@ -926,7 +926,7 @@ class Polygon(VMobject):
super().__init__(**kwargs)
self.set_points_as_corners([*vertices, vertices[0]])
def get_vertices(self) -> list[Vect3]:
def get_vertices(self) -> Vect3Array:
return self.get_start_anchors()
def round_corners(self, radius: float | None = None):

View file

@ -48,10 +48,10 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Callable, Iterable, Sequence, Union, Tuple
import numpy.typing as npt
from manimlib.typing import ManimColor, Vect3, Vect4
from manimlib.typing import ManimColor, Vect3, Vect4, Vect3Array
TimeBasedUpdater = Callable[["Mobject", float], None]
NonTimeUpdater = Callable[["Mobject"], None]
TimeBasedUpdater = Callable[["Mobject", float], "Mobject" | None]
NonTimeUpdater = Callable[["Mobject"], "Mobject" | None]
Updater = Union[TimeBasedUpdater, NonTimeUpdater]
@ -233,7 +233,7 @@ class Mobject(object):
self.set_points(mobject.get_points())
return self
def get_points(self) -> Vect3:
def get_points(self) -> Vect3Array:
return self.data["points"]
def clear_points(self) -> None:
@ -242,7 +242,7 @@ class Mobject(object):
def get_num_points(self) -> int:
return len(self.data["points"])
def get_all_points(self) -> Vect3:
def get_all_points(self) -> Vect3Array:
if self.submobjects:
return np.vstack([sm.get_points() for sm in self.get_family()])
else:
@ -251,13 +251,13 @@ class Mobject(object):
def has_points(self) -> bool:
return self.get_num_points() > 0
def get_bounding_box(self) -> Vect3:
def get_bounding_box(self) -> Vect3Array:
if self.needs_new_bounding_box:
self.data["bounding_box"] = self.compute_bounding_box()
self.needs_new_bounding_box = False
return self.data["bounding_box"]
def compute_bounding_box(self) -> Vect3:
def compute_bounding_box(self) -> Vect3Array:
all_points = np.vstack([
self.get_points(),
*(
@ -289,9 +289,9 @@ class Mobject(object):
def are_points_touching(
self,
points: Vect3,
points: Vect3Array,
buff: float = 0
) -> bool:
) -> np.ndarray:
bb = self.get_bounding_box()
mins = (bb[0] - buff)
maxs = (bb[2] + buff)
@ -1871,7 +1871,7 @@ class Mobject(object):
)
return self
def get_resized_shader_data_array(self, length: int) -> Vect3:
def get_resized_shader_data_array(self, length: int) -> np.ndarray:
# If possible, try to populate an existing array, rather
# than recreating it each frame
if len(self.shader_data) != length:
@ -1880,7 +1880,7 @@ class Mobject(object):
def read_data_to_shader(
self,
shader_data: Vect3,
shader_data: np.ndarray,
shader_data_key: str,
data_key: str
):

View file

@ -17,7 +17,7 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Iterable
from manimlib.typing import ManimColor, Vect3, RangeSpecifier
from manimlib.typing import ManimColor, Vect3, Vect3Array, VectN, RangeSpecifier
class NumberLine(Line):
@ -118,11 +118,11 @@ class NumberLine(Line):
def get_tick_marks(self) -> VGroup:
return self.ticks
def number_to_point(self, number: float | np.ndarray) -> Vect3:
def number_to_point(self, number: float | VectN) -> Vect3 | Vect3Array:
alpha = (number - self.x_min) / (self.x_max - self.x_min)
return outer_interpolate(self.get_start(), self.get_end(), alpha)
def point_to_number(self, point: Vect3) -> float:
def point_to_number(self, point: Vect3 | Vect3Array) -> float | VectN:
points = self.get_points()
start = points[0]
end = points[-1]
@ -133,11 +133,11 @@ class NumberLine(Line):
)
return interpolate(self.x_min, self.x_max, proportion)
def n2p(self, number: float) -> Vect3:
def n2p(self, number: float | VectN) -> Vect3 | Vect3Array:
"""Abbreviation for number_to_point"""
return self.number_to_point(number)
def p2n(self, point: Vect3) -> float:
def p2n(self, point: Vect3 | Vect3Array) -> float | VectN:
"""Abbreviation for point_to_number"""
return self.point_to_number(point)

View file

@ -13,7 +13,7 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING:
import numpy.typing as npt
from typing import Sequence, Tuple
from manimlib.typing import ManimColor, Vect3
from manimlib.typing import ManimColor, Vect3, Vect3Array
DEFAULT_DOT_RADIUS = 0.05
@ -32,7 +32,7 @@ class DotCloud(PMobject):
def __init__(
self,
points: Sequence[Vect3] | None = None,
points: Vect3Array | None = None,
color: ManimColor = GREY_C,
opacity: float = 1.0,
radius: float = DEFAULT_DOT_RADIUS,
@ -160,7 +160,7 @@ class TrueDot(DotCloud):
class GlowDots(DotCloud):
def __init__(
self,
points: Sequence[Vect3] | None = None,
points: Vect3Array | None = None,
color: ManimColor = YELLOW,
radius: float = DEFAULT_GLOW_DOT_RADIUS,
glow_factor: float = 2.0,

View file

@ -13,9 +13,8 @@ from manimlib.utils.iterables import resize_with_interpolation
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Callable, Sequence
import numpy.typing as npt
from manimlib.typing import ManimColor, Vect3
from typing import Callable
from manimlib.typing import ManimColor, Vect3, Vect3Array, Vect4Array
class PMobject(Mobject):
@ -32,7 +31,7 @@ class PMobject(Mobject):
self.data[key] = resize_func(self.data[key], size)
return self
def set_points(self, points: Vect3):
def set_points(self, points: Vect3Array):
if len(points) == 0:
points = np.zeros((0, 3))
super().set_points(points)
@ -41,8 +40,8 @@ class PMobject(Mobject):
def add_points(
self,
points: Sequence[Vect3],
rgbas: Vect3 | None = None,
points: Vect3Array,
rgbas: Vect4Array | None = None,
color: ManimColor | None = None,
opacity: float | None = None
):

View file

@ -18,7 +18,7 @@ if TYPE_CHECKING:
from typing import Callable, Iterable, Sequence, Tuple
from manimlib.camera.camera import Camera
from manimlib.typing import ManimColor, Vect3
from manimlib.typing import ManimColor, Vect3, Vect3Array
class Surface(Mobject):
@ -114,12 +114,12 @@ class Surface(Mobject):
def get_surface_points_and_nudged_points(
self
) -> tuple[Vect3, Vect3, Vect3]:
) -> tuple[Vect3Array, Vect3Array, Vect3Array]:
points = self.get_points()
k = len(points) // 3
return points[:k], points[k:2 * k], points[2 * k:]
def get_unit_normals(self) -> Vect3:
def get_unit_normals(self) -> Vect3Array:
s_points, du_points, dv_points = self.get_surface_points_and_nudged_points()
normals = np.cross(
(du_points - s_points) / self.epsilon,
@ -150,12 +150,12 @@ class Surface(Mobject):
def get_partial_points_array(
self,
points: Vect3,
points: Vect3Array,
a: float,
b: float,
resolution: Sequence[int],
axis: int
) -> Vect3:
) -> Vect3Array:
if len(points) == 0:
return points
nu, nv = resolution[:2]

View file

@ -42,7 +42,7 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Callable, Iterable, Sequence, Tuple
from manimlib.typing import ManimColor, Vect3, Vect4
from manimlib.typing import ManimColor, Vect3, Vect4, Vect3Array, Vect4Array
DEFAULT_STROKE_COLOR = GREY_A
DEFAULT_FILL_COLOR = GREY_C
@ -149,7 +149,7 @@ class VMobject(Mobject):
def set_rgba_array(
self,
rgba_array: Vect3,
rgba_array: Vect4Array,
name: str | None = None,
recurse: bool = False
):
@ -397,9 +397,9 @@ class VMobject(Mobject):
# Points
def set_anchors_and_handles(
self,
anchors1: Vect3,
handles: Vect3,
anchors2: Vect3
anchors1: Vect3Array,
handles: Vect3Array,
anchors2: Vect3Array
):
assert(len(anchors1) == len(handles) == len(anchors2))
nppc = self.n_points_per_curve
@ -601,7 +601,7 @@ class VMobject(Mobject):
self.change_anchor_mode("jagged")
return self
def add_subpath(self, points: Sequence[Vect3]):
def add_subpath(self, points: Vect3Array):
assert(len(points) % self.n_points_per_curve == 0)
self.append_points(points)
return self
@ -635,8 +635,8 @@ class VMobject(Mobject):
def get_subpaths_from_points(
self,
points: Sequence[Vect3]
) -> list[Sequence[Vect3]]:
points: Vect3Array
) -> Vect3Array:
nppc = self.n_points_per_curve
diffs = points[nppc - 1:-1:nppc] - points[nppc::nppc]
splits = (diffs * diffs).sum(1) > self.tolerance_for_point_equality
@ -647,13 +647,13 @@ class VMobject(Mobject):
# range(nppc, len(points), nppc)
# )
split_indices = [0, *split_indices, len(points)]
return [
return np.array([
points[i1:i2]
for i1, i2 in zip(split_indices, split_indices[1:])
if (i2 - i1) >= nppc
]
])
def get_subpaths(self) -> list[Sequence[Vect3]]:
def get_subpaths(self) -> Vect3Array:
return self.get_subpaths_from_points(self.get_points())
def get_nth_curve_points(self, n: int) -> Vect3:
@ -710,14 +710,14 @@ class VMobject(Mobject):
for i in range(nppc)
]
def get_start_anchors(self) -> list[Vect3]:
def get_start_anchors(self) -> Vect3Array:
return self.get_points()[0::self.n_points_per_curve]
def get_end_anchors(self) -> Vect3:
nppc = self.n_points_per_curve
return self.get_points()[nppc - 1::nppc]
def get_anchors(self) -> Vect3:
def get_anchors(self) -> Vect3Array:
points = self.get_points()
if len(points) == 1:
return points
@ -726,7 +726,7 @@ class VMobject(Mobject):
self.get_end_anchors(),
))))
def get_points_without_null_curves(self, atol: float = 1e-9) -> Vect3:
def get_points_without_null_curves(self, atol: float = 1e-9) -> Vect3Array:
nppc = self.n_points_per_curve
points = self.get_points()
distinct_curves = reduce(op.or_, [
@ -851,7 +851,7 @@ class VMobject(Mobject):
mob.set_points(new_points)
return self
def insert_n_curves_to_point_list(self, n: int, points: Vect3):
def insert_n_curves_to_point_list(self, n: int, points: Vect3Array):
nppc = self.n_points_per_curve
if len(points) == 1:
return np.repeat(points, nppc * n, 0)
@ -1022,7 +1022,7 @@ class VMobject(Mobject):
return wrapper
@triggers_refreshed_triangulation
def set_points(self, points: Vect3):
def set_points(self, points: Vect3Array):
super().set_points(points)
return self

View file

@ -23,7 +23,7 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Callable, Iterable, Sequence, TypeVar, Tuple
import numpy.typing as npt
from manimlib.typing import ManimColor, Vect3
from manimlib.typing import ManimColor, Vect3, VectN, Vect3Array
from manimlib.mobject.coordinate_systems import CoordinateSystem
from manimlib.mobject.mobject import Mobject
@ -35,7 +35,7 @@ def get_vectorized_rgb_gradient_function(
min_value: T,
max_value: T,
color_map: str
) -> Callable[[npt.ArrayLike], Vect3]:
) -> Callable[[VectN], Vect3Array]:
rgbs = np.array(get_colormap_list(color_map))
def func(values):
@ -57,9 +57,9 @@ def get_rgb_gradient_function(
min_value: T,
max_value: T,
color_map: str
) -> Callable[[T], Vect3]:
) -> Callable[[float], Vect3]:
vectorized_func = get_vectorized_rgb_gradient_function(min_value, max_value, color_map)
return lambda value: vectorized_func([value])[0]
return lambda value: vectorized_func(np.array([value]))[0]
def move_along_vector_field(
@ -254,7 +254,7 @@ class StreamLines(VGroup):
lines.append(line)
self.set_submobjects(lines)
def get_start_points(self) -> Vect3:
def get_start_points(self) -> Vect3Array:
cs = self.coordinate_system
sample_coords = get_sample_points_from_coordinate_system(
cs, self.step_multiple,

View file

@ -1,7 +1,7 @@
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Union, Tuple
from typing import Union, Tuple, Annotated, Literal
from colour import Color
import numpy as np
@ -9,10 +9,20 @@ if TYPE_CHECKING:
ManimColor = Union[str, Color, None]
RangeSpecifier = Tuple[float, float, float] | Tuple[float, float]
# TODO, Nothing about these actually specifies length,
# they are so far just about code readability
Vect2 = np.ndarray[int, np.dtype[np.float64]] # TODO, specify length of 2
Vect3 = np.ndarray[int, np.dtype[np.float64]] # TODO, specify length of 3
Vect4 = np.ndarray[int, np.dtype[np.float64]] # TODO, specify length of 4
VectN = np.ndarray[int, np.dtype[np.float64]]
Matrix3x3 = np.ndarray[int, np.dtype[np.float64]] # TODO, specify output size
# These are various alternate names for np.ndarray meant to specify
# certain shapes.
#
# In theory, these annotations could be used to check arrays sizes
# at runtime, but at the moment nothing actually uses them, and
# the names are here primarily to enhance readibility and allow
# for some stronger type checking if numpy has stronger typing
# in the future
FloatArray = np.ndarray[int, np.dtype[np.float64]]
Vect2 = Annotated[FloatArray, Literal[2]]
Vect3 = Annotated[FloatArray, Literal[3]]
Vect4 = Annotated[FloatArray, Literal[4]]
VectN = Annotated[FloatArray, Literal["N"]]
Matrix3x3 = Annotated[FloatArray, Literal[3, 3]]
Vect2Array = Annotated[FloatArray, Literal["N", 2]]
Vect3Array = Annotated[FloatArray, Literal["N", 3]]
Vect4Array = Annotated[FloatArray, Literal["N", 4]]

View file

@ -33,7 +33,7 @@ def list_difference_update(l1: Iterable[T], l2: Iterable[T]) -> list[T]:
return [e for e in l1 if e not in l2]
def adjacent_n_tuples(objects: Sequence[T], n: int) -> zip[tuple[T, T]]:
def adjacent_n_tuples(objects: Sequence[T], n: int) -> zip[tuple[T, ...]]:
return zip(*[
[*objects[k:], *objects[:k]]
for k in range(n)

View file

@ -19,7 +19,7 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Callable, Sequence, List, Tuple
from manimlib.typing import ManimColor, Vect2, Vect3, Vect4, VectN, Matrix3x3
from manimlib.typing import ManimColor, Vect2, Vect3, Vect4, VectN, Matrix3x3, Vect3Array, Vect2Array
def cross(v1: Vect3 | List[float], v2: Vect3 | List[float]) -> Vect3:
@ -369,7 +369,7 @@ def norm_squared(v: VectN | List[float]) -> float:
# TODO, fails for polygons drawn over themselves
def earclip_triangulation(verts: Vect2 | Vect3, ring_ends: list[int]) -> list[int]:
def earclip_triangulation(verts: Vect3Array | Vect2Array, ring_ends: list[int]) -> list[int]:
"""
Returns a list of indices giving a triangulation
of a polygon, potentially with holes