diff --git a/manimlib/camera/camera.py b/manimlib/camera/camera.py index 9a2fca68..735f1676 100644 --- a/manimlib/camera/camera.py +++ b/manimlib/camera/camera.py @@ -26,12 +26,9 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: from manimlib.shader_wrapper import ShaderWrapper - from manimlib.typing import ManimColor + from manimlib.typing import ManimColor, Vect3 from typing import Sequence - Vect3 = np.ndarray[int, np.dtype[np.float64]] - - class CameraFrame(Mobject): def __init__( self, diff --git a/manimlib/mobject/coordinate_systems.py b/manimlib/mobject/coordinate_systems.py index 92e68370..59f2b6ec 100644 --- a/manimlib/mobject/coordinate_systems.py +++ b/manimlib/mobject/coordinate_systems.py @@ -34,7 +34,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: from typing import Callable, Iterable, Sequence, Type, TypeVar from manimlib.mobject.mobject import Mobject - from manimlib.typing import ManimColor, Vect3, RangeSpecifier + from manimlib.typing import ManimColor, Vect3, Vect3Array, VectN, RangeSpecifier T = TypeVar("T", bound=Mobject) @@ -61,18 +61,18 @@ class CoordinateSystem(ABC): self.num_sampled_graph_points_per_tick = num_sampled_graph_points_per_tick @abstractmethod - def coords_to_point(self, *coords: float) -> Vect3: + def coords_to_point(self, *coords: float | VectN) -> Vect3 | Vect3Array: raise Exception("Not implemented") @abstractmethod - def point_to_coords(self, point: Vect3) -> tuple[float, ...]: + def point_to_coords(self, point: Vect3 | Vect3Array) -> tuple[float | VectN, ...]: raise Exception("Not implemented") - def c2p(self, *coords: float): + def c2p(self, *coords: float) -> Vect3 | Vect3Array: """Abbreviation for coords_to_point""" return self.coords_to_point(*coords) - def p2c(self, point: Vect3): + def p2c(self, point: Vect3) -> tuple[float | VectN, ...]: """Abbreviation for point_to_coords""" return self.point_to_coords(point) @@ -302,8 +302,8 @@ class CoordinateSystem(ABC): return self.get_h_line(self.i2gp(x, graph), **kwargs) def get_scatterplot(self, - x_values: Vect3, - y_values: Vect3, + x_values: Vect3Array, + y_values: Vect3Array, **dot_config): return DotCloud(self.c2p(x_values, y_values), **dot_config) @@ -449,14 +449,14 @@ class Axes(VGroup, CoordinateSystem): axis.shift(-axis.n2p(0)) return axis - def coords_to_point(self, *coords: float) -> Vect3: + def coords_to_point(self, *coords: float | VectN) -> Vect3 | Vect3Array: origin = self.x_axis.number_to_point(0) return origin + sum( axis.number_to_point(coord) - origin for axis, coord in zip(self.get_axes(), coords) ) - def point_to_coords(self, point: Vect3) -> tuple[float, ...]: + def point_to_coords(self, point: Vect3 | Vect3Array) -> tuple[float | VectN, ...]: return tuple([ axis.point_to_number(point) for axis in self.get_axes() diff --git a/manimlib/mobject/geometry.py b/manimlib/mobject/geometry.py index 1c0204e4..dd6cb200 100644 --- a/manimlib/mobject/geometry.py +++ b/manimlib/mobject/geometry.py @@ -30,7 +30,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: from typing import Iterable - from manimlib.typing import ManimColor, Vect3 + from manimlib.typing import ManimColor, Vect3, Vect3Array DEFAULT_DOT_RADIUS = 0.08 @@ -926,7 +926,7 @@ class Polygon(VMobject): super().__init__(**kwargs) self.set_points_as_corners([*vertices, vertices[0]]) - def get_vertices(self) -> list[Vect3]: + def get_vertices(self) -> Vect3Array: return self.get_start_anchors() def round_corners(self, radius: float | None = None): diff --git a/manimlib/mobject/mobject.py b/manimlib/mobject/mobject.py index dd15b3a9..5a4aa40e 100644 --- a/manimlib/mobject/mobject.py +++ b/manimlib/mobject/mobject.py @@ -48,10 +48,10 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: from typing import Callable, Iterable, Sequence, Union, Tuple import numpy.typing as npt - from manimlib.typing import ManimColor, Vect3, Vect4 + from manimlib.typing import ManimColor, Vect3, Vect4, Vect3Array - TimeBasedUpdater = Callable[["Mobject", float], None] - NonTimeUpdater = Callable[["Mobject"], None] + TimeBasedUpdater = Callable[["Mobject", float], "Mobject" | None] + NonTimeUpdater = Callable[["Mobject"], "Mobject" | None] Updater = Union[TimeBasedUpdater, NonTimeUpdater] @@ -233,7 +233,7 @@ class Mobject(object): self.set_points(mobject.get_points()) return self - def get_points(self) -> Vect3: + def get_points(self) -> Vect3Array: return self.data["points"] def clear_points(self) -> None: @@ -242,7 +242,7 @@ class Mobject(object): def get_num_points(self) -> int: return len(self.data["points"]) - def get_all_points(self) -> Vect3: + def get_all_points(self) -> Vect3Array: if self.submobjects: return np.vstack([sm.get_points() for sm in self.get_family()]) else: @@ -251,13 +251,13 @@ class Mobject(object): def has_points(self) -> bool: return self.get_num_points() > 0 - def get_bounding_box(self) -> Vect3: + def get_bounding_box(self) -> Vect3Array: if self.needs_new_bounding_box: self.data["bounding_box"] = self.compute_bounding_box() self.needs_new_bounding_box = False return self.data["bounding_box"] - def compute_bounding_box(self) -> Vect3: + def compute_bounding_box(self) -> Vect3Array: all_points = np.vstack([ self.get_points(), *( @@ -289,9 +289,9 @@ class Mobject(object): def are_points_touching( self, - points: Vect3, + points: Vect3Array, buff: float = 0 - ) -> bool: + ) -> np.ndarray: bb = self.get_bounding_box() mins = (bb[0] - buff) maxs = (bb[2] + buff) @@ -1871,7 +1871,7 @@ class Mobject(object): ) return self - def get_resized_shader_data_array(self, length: int) -> Vect3: + def get_resized_shader_data_array(self, length: int) -> np.ndarray: # If possible, try to populate an existing array, rather # than recreating it each frame if len(self.shader_data) != length: @@ -1880,7 +1880,7 @@ class Mobject(object): def read_data_to_shader( self, - shader_data: Vect3, + shader_data: np.ndarray, shader_data_key: str, data_key: str ): diff --git a/manimlib/mobject/number_line.py b/manimlib/mobject/number_line.py index 999abc4b..e8149f07 100644 --- a/manimlib/mobject/number_line.py +++ b/manimlib/mobject/number_line.py @@ -17,7 +17,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: from typing import Iterable - from manimlib.typing import ManimColor, Vect3, RangeSpecifier + from manimlib.typing import ManimColor, Vect3, Vect3Array, VectN, RangeSpecifier class NumberLine(Line): @@ -118,11 +118,11 @@ class NumberLine(Line): def get_tick_marks(self) -> VGroup: return self.ticks - def number_to_point(self, number: float | np.ndarray) -> Vect3: + def number_to_point(self, number: float | VectN) -> Vect3 | Vect3Array: alpha = (number - self.x_min) / (self.x_max - self.x_min) return outer_interpolate(self.get_start(), self.get_end(), alpha) - def point_to_number(self, point: Vect3) -> float: + def point_to_number(self, point: Vect3 | Vect3Array) -> float | VectN: points = self.get_points() start = points[0] end = points[-1] @@ -133,11 +133,11 @@ class NumberLine(Line): ) return interpolate(self.x_min, self.x_max, proportion) - def n2p(self, number: float) -> Vect3: + def n2p(self, number: float | VectN) -> Vect3 | Vect3Array: """Abbreviation for number_to_point""" return self.number_to_point(number) - def p2n(self, point: Vect3) -> float: + def p2n(self, point: Vect3 | Vect3Array) -> float | VectN: """Abbreviation for point_to_number""" return self.point_to_number(point) diff --git a/manimlib/mobject/types/dot_cloud.py b/manimlib/mobject/types/dot_cloud.py index b90cdf27..aa0f3cb4 100644 --- a/manimlib/mobject/types/dot_cloud.py +++ b/manimlib/mobject/types/dot_cloud.py @@ -13,7 +13,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: import numpy.typing as npt from typing import Sequence, Tuple - from manimlib.typing import ManimColor, Vect3 + from manimlib.typing import ManimColor, Vect3, Vect3Array DEFAULT_DOT_RADIUS = 0.05 @@ -32,7 +32,7 @@ class DotCloud(PMobject): def __init__( self, - points: Sequence[Vect3] | None = None, + points: Vect3Array | None = None, color: ManimColor = GREY_C, opacity: float = 1.0, radius: float = DEFAULT_DOT_RADIUS, @@ -160,7 +160,7 @@ class TrueDot(DotCloud): class GlowDots(DotCloud): def __init__( self, - points: Sequence[Vect3] | None = None, + points: Vect3Array | None = None, color: ManimColor = YELLOW, radius: float = DEFAULT_GLOW_DOT_RADIUS, glow_factor: float = 2.0, diff --git a/manimlib/mobject/types/point_cloud_mobject.py b/manimlib/mobject/types/point_cloud_mobject.py index e829c349..d4df879e 100644 --- a/manimlib/mobject/types/point_cloud_mobject.py +++ b/manimlib/mobject/types/point_cloud_mobject.py @@ -13,9 +13,8 @@ from manimlib.utils.iterables import resize_with_interpolation from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Callable, Sequence - import numpy.typing as npt - from manimlib.typing import ManimColor, Vect3 + from typing import Callable + from manimlib.typing import ManimColor, Vect3, Vect3Array, Vect4Array class PMobject(Mobject): @@ -32,7 +31,7 @@ class PMobject(Mobject): self.data[key] = resize_func(self.data[key], size) return self - def set_points(self, points: Vect3): + def set_points(self, points: Vect3Array): if len(points) == 0: points = np.zeros((0, 3)) super().set_points(points) @@ -41,8 +40,8 @@ class PMobject(Mobject): def add_points( self, - points: Sequence[Vect3], - rgbas: Vect3 | None = None, + points: Vect3Array, + rgbas: Vect4Array | None = None, color: ManimColor | None = None, opacity: float | None = None ): diff --git a/manimlib/mobject/types/surface.py b/manimlib/mobject/types/surface.py index 0d21012d..ea3dabec 100644 --- a/manimlib/mobject/types/surface.py +++ b/manimlib/mobject/types/surface.py @@ -18,7 +18,7 @@ if TYPE_CHECKING: from typing import Callable, Iterable, Sequence, Tuple from manimlib.camera.camera import Camera - from manimlib.typing import ManimColor, Vect3 + from manimlib.typing import ManimColor, Vect3, Vect3Array class Surface(Mobject): @@ -114,12 +114,12 @@ class Surface(Mobject): def get_surface_points_and_nudged_points( self - ) -> tuple[Vect3, Vect3, Vect3]: + ) -> tuple[Vect3Array, Vect3Array, Vect3Array]: points = self.get_points() k = len(points) // 3 return points[:k], points[k:2 * k], points[2 * k:] - def get_unit_normals(self) -> Vect3: + def get_unit_normals(self) -> Vect3Array: s_points, du_points, dv_points = self.get_surface_points_and_nudged_points() normals = np.cross( (du_points - s_points) / self.epsilon, @@ -150,12 +150,12 @@ class Surface(Mobject): def get_partial_points_array( self, - points: Vect3, + points: Vect3Array, a: float, b: float, resolution: Sequence[int], axis: int - ) -> Vect3: + ) -> Vect3Array: if len(points) == 0: return points nu, nv = resolution[:2] diff --git a/manimlib/mobject/types/vectorized_mobject.py b/manimlib/mobject/types/vectorized_mobject.py index 802dbe0a..ded26245 100644 --- a/manimlib/mobject/types/vectorized_mobject.py +++ b/manimlib/mobject/types/vectorized_mobject.py @@ -42,7 +42,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: from typing import Callable, Iterable, Sequence, Tuple - from manimlib.typing import ManimColor, Vect3, Vect4 + from manimlib.typing import ManimColor, Vect3, Vect4, Vect3Array, Vect4Array DEFAULT_STROKE_COLOR = GREY_A DEFAULT_FILL_COLOR = GREY_C @@ -149,7 +149,7 @@ class VMobject(Mobject): def set_rgba_array( self, - rgba_array: Vect3, + rgba_array: Vect4Array, name: str | None = None, recurse: bool = False ): @@ -397,9 +397,9 @@ class VMobject(Mobject): # Points def set_anchors_and_handles( self, - anchors1: Vect3, - handles: Vect3, - anchors2: Vect3 + anchors1: Vect3Array, + handles: Vect3Array, + anchors2: Vect3Array ): assert(len(anchors1) == len(handles) == len(anchors2)) nppc = self.n_points_per_curve @@ -601,7 +601,7 @@ class VMobject(Mobject): self.change_anchor_mode("jagged") return self - def add_subpath(self, points: Sequence[Vect3]): + def add_subpath(self, points: Vect3Array): assert(len(points) % self.n_points_per_curve == 0) self.append_points(points) return self @@ -635,8 +635,8 @@ class VMobject(Mobject): def get_subpaths_from_points( self, - points: Sequence[Vect3] - ) -> list[Sequence[Vect3]]: + points: Vect3Array + ) -> Vect3Array: nppc = self.n_points_per_curve diffs = points[nppc - 1:-1:nppc] - points[nppc::nppc] splits = (diffs * diffs).sum(1) > self.tolerance_for_point_equality @@ -647,13 +647,13 @@ class VMobject(Mobject): # range(nppc, len(points), nppc) # ) split_indices = [0, *split_indices, len(points)] - return [ + return np.array([ points[i1:i2] for i1, i2 in zip(split_indices, split_indices[1:]) if (i2 - i1) >= nppc - ] + ]) - def get_subpaths(self) -> list[Sequence[Vect3]]: + def get_subpaths(self) -> Vect3Array: return self.get_subpaths_from_points(self.get_points()) def get_nth_curve_points(self, n: int) -> Vect3: @@ -710,14 +710,14 @@ class VMobject(Mobject): for i in range(nppc) ] - def get_start_anchors(self) -> list[Vect3]: + def get_start_anchors(self) -> Vect3Array: return self.get_points()[0::self.n_points_per_curve] def get_end_anchors(self) -> Vect3: nppc = self.n_points_per_curve return self.get_points()[nppc - 1::nppc] - def get_anchors(self) -> Vect3: + def get_anchors(self) -> Vect3Array: points = self.get_points() if len(points) == 1: return points @@ -726,7 +726,7 @@ class VMobject(Mobject): self.get_end_anchors(), )))) - def get_points_without_null_curves(self, atol: float = 1e-9) -> Vect3: + def get_points_without_null_curves(self, atol: float = 1e-9) -> Vect3Array: nppc = self.n_points_per_curve points = self.get_points() distinct_curves = reduce(op.or_, [ @@ -851,7 +851,7 @@ class VMobject(Mobject): mob.set_points(new_points) return self - def insert_n_curves_to_point_list(self, n: int, points: Vect3): + def insert_n_curves_to_point_list(self, n: int, points: Vect3Array): nppc = self.n_points_per_curve if len(points) == 1: return np.repeat(points, nppc * n, 0) @@ -1022,7 +1022,7 @@ class VMobject(Mobject): return wrapper @triggers_refreshed_triangulation - def set_points(self, points: Vect3): + def set_points(self, points: Vect3Array): super().set_points(points) return self diff --git a/manimlib/mobject/vector_field.py b/manimlib/mobject/vector_field.py index a1a51990..6634539c 100644 --- a/manimlib/mobject/vector_field.py +++ b/manimlib/mobject/vector_field.py @@ -23,7 +23,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: from typing import Callable, Iterable, Sequence, TypeVar, Tuple import numpy.typing as npt - from manimlib.typing import ManimColor, Vect3 + from manimlib.typing import ManimColor, Vect3, VectN, Vect3Array from manimlib.mobject.coordinate_systems import CoordinateSystem from manimlib.mobject.mobject import Mobject @@ -35,7 +35,7 @@ def get_vectorized_rgb_gradient_function( min_value: T, max_value: T, color_map: str -) -> Callable[[npt.ArrayLike], Vect3]: +) -> Callable[[VectN], Vect3Array]: rgbs = np.array(get_colormap_list(color_map)) def func(values): @@ -57,9 +57,9 @@ def get_rgb_gradient_function( min_value: T, max_value: T, color_map: str -) -> Callable[[T], Vect3]: +) -> Callable[[float], Vect3]: vectorized_func = get_vectorized_rgb_gradient_function(min_value, max_value, color_map) - return lambda value: vectorized_func([value])[0] + return lambda value: vectorized_func(np.array([value]))[0] def move_along_vector_field( @@ -254,7 +254,7 @@ class StreamLines(VGroup): lines.append(line) self.set_submobjects(lines) - def get_start_points(self) -> Vect3: + def get_start_points(self) -> Vect3Array: cs = self.coordinate_system sample_coords = get_sample_points_from_coordinate_system( cs, self.step_multiple, diff --git a/manimlib/typing.py b/manimlib/typing.py index 94811660..6764b412 100644 --- a/manimlib/typing.py +++ b/manimlib/typing.py @@ -1,7 +1,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Union, Tuple + from typing import Union, Tuple, Annotated, Literal from colour import Color import numpy as np @@ -9,10 +9,20 @@ if TYPE_CHECKING: ManimColor = Union[str, Color, None] RangeSpecifier = Tuple[float, float, float] | Tuple[float, float] - # TODO, Nothing about these actually specifies length, - # they are so far just about code readability - Vect2 = np.ndarray[int, np.dtype[np.float64]] # TODO, specify length of 2 - Vect3 = np.ndarray[int, np.dtype[np.float64]] # TODO, specify length of 3 - Vect4 = np.ndarray[int, np.dtype[np.float64]] # TODO, specify length of 4 - VectN = np.ndarray[int, np.dtype[np.float64]] - Matrix3x3 = np.ndarray[int, np.dtype[np.float64]] # TODO, specify output size \ No newline at end of file + # These are various alternate names for np.ndarray meant to specify + # certain shapes. + # + # In theory, these annotations could be used to check arrays sizes + # at runtime, but at the moment nothing actually uses them, and + # the names are here primarily to enhance readibility and allow + # for some stronger type checking if numpy has stronger typing + # in the future + FloatArray = np.ndarray[int, np.dtype[np.float64]] + Vect2 = Annotated[FloatArray, Literal[2]] + Vect3 = Annotated[FloatArray, Literal[3]] + Vect4 = Annotated[FloatArray, Literal[4]] + VectN = Annotated[FloatArray, Literal["N"]] + Matrix3x3 = Annotated[FloatArray, Literal[3, 3]] + Vect2Array = Annotated[FloatArray, Literal["N", 2]] + Vect3Array = Annotated[FloatArray, Literal["N", 3]] + Vect4Array = Annotated[FloatArray, Literal["N", 4]] diff --git a/manimlib/utils/iterables.py b/manimlib/utils/iterables.py index ed11fcbb..bcde11f8 100644 --- a/manimlib/utils/iterables.py +++ b/manimlib/utils/iterables.py @@ -33,7 +33,7 @@ def list_difference_update(l1: Iterable[T], l2: Iterable[T]) -> list[T]: return [e for e in l1 if e not in l2] -def adjacent_n_tuples(objects: Sequence[T], n: int) -> zip[tuple[T, T]]: +def adjacent_n_tuples(objects: Sequence[T], n: int) -> zip[tuple[T, ...]]: return zip(*[ [*objects[k:], *objects[:k]] for k in range(n) diff --git a/manimlib/utils/space_ops.py b/manimlib/utils/space_ops.py index 8a6acd07..76d207bd 100644 --- a/manimlib/utils/space_ops.py +++ b/manimlib/utils/space_ops.py @@ -19,7 +19,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: from typing import Callable, Sequence, List, Tuple - from manimlib.typing import ManimColor, Vect2, Vect3, Vect4, VectN, Matrix3x3 + from manimlib.typing import ManimColor, Vect2, Vect3, Vect4, VectN, Matrix3x3, Vect3Array, Vect2Array def cross(v1: Vect3 | List[float], v2: Vect3 | List[float]) -> Vect3: @@ -369,7 +369,7 @@ def norm_squared(v: VectN | List[float]) -> float: # TODO, fails for polygons drawn over themselves -def earclip_triangulation(verts: Vect2 | Vect3, ring_ends: list[int]) -> list[int]: +def earclip_triangulation(verts: Vect3Array | Vect2Array, ring_ends: list[int]) -> list[int]: """ Returns a list of indices giving a triangulation of a polygon, potentially with holes