From 19187ead06deeff0a71e2393e8e82280becb2983 Mon Sep 17 00:00:00 2001 From: TonyCrane Date: Sun, 13 Feb 2022 18:56:50 +0800 Subject: [PATCH] chore: add type hints to manimlib.mobject.types --- manimlib/mobject/types/dot_cloud.py | 54 ++-- manimlib/mobject/types/image_mobject.py | 17 +- manimlib/mobject/types/point_cloud_mobject.py | 44 ++- manimlib/mobject/types/surface.py | 83 ++++-- manimlib/mobject/types/vectorized_mobject.py | 272 +++++++++++------- 5 files changed, 306 insertions(+), 164 deletions(-) diff --git a/manimlib/mobject/types/dot_cloud.py b/manimlib/mobject/types/dot_cloud.py index 03511ecb..d44bdc25 100644 --- a/manimlib/mobject/types/dot_cloud.py +++ b/manimlib/mobject/types/dot_cloud.py @@ -1,4 +1,7 @@ +from __future__ import annotations + import numpy as np +import numpy.typing as npt import moderngl from manimlib.constants import GREY_C @@ -29,27 +32,31 @@ class DotCloud(PMobject): ], } - def __init__(self, points=None, **kwargs): + def __init__(self, points: npt.ArrayLike = None, **kwargs): super().__init__(**kwargs) if points is not None: self.set_points(points) - def init_data(self): + def init_data(self) -> None: super().init_data() self.data["radii"] = np.zeros((1, 1)) self.set_radius(self.radius) - def init_uniforms(self): + def init_uniforms(self) -> None: super().init_uniforms() self.uniforms["glow_factor"] = self.glow_factor - def to_grid(self, n_rows, n_cols, n_layers=1, - buff_ratio=None, - h_buff_ratio=1.0, - v_buff_ratio=1.0, - d_buff_ratio=1.0, - height=DEFAULT_GRID_HEIGHT, - ): + def to_grid( + self, + n_rows: int, + n_cols: int, + n_layers: int = 1, + buff_ratio: float | None = None, + h_buff_ratio: float = 1.0, + v_buff_ratio: float = 1.0, + d_buff_ratio: float = 1.0, + height: float = DEFAULT_GRID_HEIGHT, + ): n_points = n_rows * n_cols * n_layers points = np.repeat(range(n_points), 3, axis=0).reshape((n_points, 3)) points[:, 0] = points[:, 0] % n_cols @@ -74,50 +81,55 @@ class DotCloud(PMobject): self.center() return self - def set_radii(self, radii): + def set_radii(self, radii: npt.ArrayLike): n_points = len(self.get_points()) radii = np.array(radii).reshape((len(radii), 1)) self.data["radii"] = resize_preserving_order(radii, n_points) self.refresh_bounding_box() return self - def get_radii(self): + def get_radii(self) -> np.ndarray: return self.data["radii"] - def set_radius(self, radius): + def set_radius(self, radius: float): self.data["radii"][:] = radius self.refresh_bounding_box() return self - def get_radius(self): + def get_radius(self) -> float: return self.get_radii().max() - def set_glow_factor(self, glow_factor): + def set_glow_factor(self, glow_factor: float) -> None: self.uniforms["glow_factor"] = glow_factor - def get_glow_factor(self): + def get_glow_factor(self) -> float: return self.uniforms["glow_factor"] - def compute_bounding_box(self): + def compute_bounding_box(self) -> np.ndarray: bb = super().compute_bounding_box() radius = self.get_radius() bb[0] += np.full((3,), -radius) bb[2] += np.full((3,), radius) return bb - def scale(self, scale_factor, scale_radii=True, **kwargs): + def scale( + self, + scale_factor: float | npt.ArrayLike, + scale_radii: bool = True, + **kwargs + ): super().scale(scale_factor, **kwargs) if scale_radii: self.set_radii(scale_factor * self.get_radii()) return self - def make_3d(self, reflectiveness=0.5, shadow=0.2): + def make_3d(self, reflectiveness: float = 0.5, shadow: float = 0.2): self.set_reflectiveness(reflectiveness) self.set_shadow(shadow) self.apply_depth_test() return self - def get_shader_data(self): + def get_shader_data(self) -> np.ndarray: shader_data = super().get_shader_data() self.read_data_to_shader(shader_data, "radius", "radii") self.read_data_to_shader(shader_data, "color", "rgbas") @@ -125,7 +137,7 @@ class DotCloud(PMobject): class TrueDot(DotCloud): - def __init__(self, center=ORIGIN, **kwargs): + def __init__(self, center: np.ndarray = ORIGIN, **kwargs): super().__init__(points=[center], **kwargs) diff --git a/manimlib/mobject/types/image_mobject.py b/manimlib/mobject/types/image_mobject.py index 334b389d..d3f11f2b 100644 --- a/manimlib/mobject/types/image_mobject.py +++ b/manimlib/mobject/types/image_mobject.py @@ -1,5 +1,6 @@ -import numpy as np +from __future__ import annotations +import numpy as np from PIL import Image from manimlib.constants import * @@ -21,33 +22,33 @@ class ImageMobject(Mobject): ] } - def __init__(self, filename, **kwargs): + def __init__(self, filename: str, **kwargs): self.set_image_path(get_full_raster_image_path(filename)) super().__init__(**kwargs) - def set_image_path(self, path): + def set_image_path(self, path: str) -> None: self.path = path self.image = Image.open(path) self.texture_paths = {"Texture": path} - def init_data(self): + def init_data(self) -> None: self.data = { "points": np.array([UL, DL, UR, DR]), "im_coords": np.array([(0, 0), (0, 1), (1, 0), (1, 1)]), "opacity": np.array([[self.opacity]], dtype=np.float32), } - def init_points(self): + def init_points(self) -> None: size = self.image.size self.set_width(2 * size[0] / size[1], stretch=True) self.set_height(self.height) - def set_opacity(self, opacity, recurse=True): + def set_opacity(self, opacity: float, recurse: bool = True): for mob in self.get_family(recurse): mob.data["opacity"] = np.array([[o] for o in listify(opacity)]) return self - def point_to_rgb(self, point): + def point_to_rgb(self, point: np.ndarray) -> np.ndarray: x0, y0 = self.get_corner(UL)[:2] x1, y1 = self.get_corner(DR)[:2] x_alpha = inverse_interpolate(x0, x1, point[0]) @@ -63,7 +64,7 @@ class ImageMobject(Mobject): )) return np.array(rgb) / 255 - def get_shader_data(self): + def get_shader_data(self) -> np.ndarray: shader_data = super().get_shader_data() self.read_data_to_shader(shader_data, "im_coords", "im_coords") self.read_data_to_shader(shader_data, "opacity", "opacity") diff --git a/manimlib/mobject/types/point_cloud_mobject.py b/manimlib/mobject/types/point_cloud_mobject.py index 28ccee7e..2af3e191 100644 --- a/manimlib/mobject/types/point_cloud_mobject.py +++ b/manimlib/mobject/types/point_cloud_mobject.py @@ -1,3 +1,10 @@ +from __future__ import annotations + +from typing import Callable, Sequence, Union + +import colour +import numpy.typing as npt + from manimlib.constants import * from manimlib.mobject.mobject import Mobject from manimlib.utils.color import color_gradient @@ -6,26 +13,39 @@ from manimlib.utils.iterables import resize_with_interpolation from manimlib.utils.iterables import resize_array +Color = Union[str, colour.Color, Sequence[float]] + + class PMobject(Mobject): CONFIG = { "opacity": 1.0, } - def resize_points(self, size, resize_func=resize_array): + def resize_points( + self, + size: int, + resize_func: Callable[[np.ndarray, int], np.ndarray] = resize_array + ): # TODO for key in self.data: if key == "bounding_box": continue if len(self.data[key]) != size: - self.data[key] = resize_array(self.data[key], size) + self.data[key] = resize_func(self.data[key], size) return self - def set_points(self, points): + def set_points(self, points: npt.ArrayLike): super().set_points(points) self.resize_points(len(points)) return self - def add_points(self, points, rgbas=None, color=None, opacity=None): + def add_points( + self, + points: npt.ArrayLike, + rgbas: np.ndarray | None = None, + color: Color | None = None, + opacity: float | None = None + ): """ points must be a Nx3 numpy array, as must rgbas if it is not None """ @@ -44,20 +64,20 @@ class PMobject(Mobject): self.data["rgbas"][-len(new_rgbas):] = new_rgbas return self - def set_color_by_gradient(self, *colors): + def set_color_by_gradient(self, *colors: Color): self.data["rgbas"] = np.array(list(map( color_to_rgba, color_gradient(colors, self.get_num_points()) ))) return self - def match_colors(self, pmobject): + def match_colors(self, pmobject: "PMobject"): self.data["rgbas"][:] = resize_with_interpolation( pmobject.data["rgbas"], self.get_num_points() ) return self - def filter_out(self, condition): + def filter_out(self, condition: Callable[[np.ndarray], bool]): for mob in self.family_members_with_points(): to_keep = ~np.apply_along_axis(condition, 1, mob.get_points()) for key in mob.data: @@ -66,7 +86,7 @@ class PMobject(Mobject): mob.data[key] = mob.data[key][to_keep] return self - def sort_points(self, function=lambda p: p[0]): + def sort_points(self, function: Callable[[np.ndarray]] = lambda p: p[0]): """ function is any map from R^3 to R """ @@ -86,11 +106,11 @@ class PMobject(Mobject): ]) return self - def point_from_proportion(self, alpha): + def point_from_proportion(self, alpha: float) -> np.ndarray: index = alpha * (self.get_num_points() - 1) return self.get_points()[int(index)] - def pointwise_become_partial(self, pmobject, a, b): + def pointwise_become_partial(self, pmobject: "PMobject", a: float, b: float): lower_index = int(a * pmobject.get_num_points()) upper_index = int(b * pmobject.get_num_points()) for key in self.data: @@ -101,7 +121,7 @@ class PMobject(Mobject): class PGroup(PMobject): - def __init__(self, *pmobs, **kwargs): + def __init__(self, *pmobs: PMobject, **kwargs): if not all([isinstance(m, PMobject) for m in pmobs]): raise Exception("All submobjects must be of type PMobject") super().__init__(*pmobs, **kwargs) @@ -112,6 +132,6 @@ class Point(PMobject): "color": BLACK, } - def __init__(self, location=ORIGIN, **kwargs): + def __init__(self, location: np.ndarray = ORIGIN, **kwargs): super().__init__(**kwargs) self.add_points([location]) diff --git a/manimlib/mobject/types/surface.py b/manimlib/mobject/types/surface.py index 1160c1ae..a8b4fd5c 100644 --- a/manimlib/mobject/types/surface.py +++ b/manimlib/mobject/types/surface.py @@ -1,7 +1,13 @@ -import numpy as np +from __future__ import annotations + +from typing import Iterable, Callable + import moderngl +import numpy as np +import numpy.typing as npt from manimlib.constants import * +from manimlib.camera.camera import Camera from manimlib.mobject.mobject import Mobject from manimlib.utils.bezier import integer_interpolate from manimlib.utils.bezier import interpolate @@ -42,7 +48,7 @@ class Surface(Mobject): super().__init__(**kwargs) self.compute_triangle_indices() - def uv_func(self, u, v): + def uv_func(self, u: float, v: float) -> tuple[float, float, float]: # To be implemented in subclasses return (u, v, 0.0) @@ -85,15 +91,17 @@ class Surface(Mobject): indices[5::6] = index_grid[+1:, +1:].flatten() # Bottom right self.triangle_indices = indices - def get_triangle_indices(self): + def get_triangle_indices(self) -> np.ndarray: return self.triangle_indices - def get_surface_points_and_nudged_points(self): + def get_surface_points_and_nudged_points( + self + ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: points = self.get_points() k = len(points) // 3 return points[:k], points[k:2 * k], points[2 * k:] - def get_unit_normals(self): + def get_unit_normals(self) -> np.ndarray: s_points, du_points, dv_points = self.get_surface_points_and_nudged_points() normals = np.cross( (du_points - s_points) / self.epsilon, @@ -101,7 +109,13 @@ class Surface(Mobject): ) return normalize_along_axis(normals, 1) - def pointwise_become_partial(self, smobject, a, b, axis=None): + def pointwise_become_partial( + self, + smobject: "Surface", + a: float, + b: float, + axis: np.ndarray | None = None + ): assert(isinstance(smobject, Surface)) if axis is None: axis = self.prefered_creation_axis @@ -116,7 +130,14 @@ class Surface(Mobject): ])) return self - def get_partial_points_array(self, points, a, b, resolution, axis): + def get_partial_points_array( + self, + points: np.ndarray, + a: float, + b: float, + resolution: npt.ArrayLike, + axis: int + ) -> np.ndarray: if len(points) == 0: return points nu, nv = resolution[:2] @@ -149,7 +170,7 @@ class Surface(Mobject): ).reshape(shape) return points.reshape((nu * nv, *resolution[2:])) - def sort_faces_back_to_front(self, vect=OUT): + def sort_faces_back_to_front(self, vect: np.ndarray = OUT): tri_is = self.triangle_indices indices = list(range(len(tri_is) // 3)) points = self.get_points() @@ -162,13 +183,13 @@ class Surface(Mobject): tri_is[k::3] = tri_is[k::3][indices] return self - def always_sort_to_camera(self, camera): + def always_sort_to_camera(self, camera: Camera): self.add_updater(lambda m: m.sort_faces_back_to_front( camera.get_location() - self.get_center() )) # For shaders - def get_shader_data(self): + def get_shader_data(self) -> np.ndarray: s_points, du_points, dv_points = self.get_surface_points_and_nudged_points() shader_data = self.get_resized_shader_data_array(len(s_points)) if "points" not in self.locked_data_keys: @@ -178,16 +199,22 @@ class Surface(Mobject): self.fill_in_shader_color_info(shader_data) return shader_data - def fill_in_shader_color_info(self, shader_data): + def fill_in_shader_color_info(self, shader_data: np.ndarray) -> np.ndarray: self.read_data_to_shader(shader_data, "color", "rgbas") return shader_data - def get_shader_vert_indices(self): + def get_shader_vert_indices(self) -> np.ndarray: return self.get_triangle_indices() class ParametricSurface(Surface): - def __init__(self, uv_func, u_range=(0, 1), v_range=(0, 1), **kwargs): + def __init__( + self, + uv_func: Callable[[float, float], Iterable[float]], + u_range: tuple[float, float] = (0, 1), + v_range: tuple[float, float] = (0, 1), + **kwargs + ): self.passed_uv_func = uv_func super().__init__(u_range=u_range, v_range=v_range, **kwargs) @@ -200,7 +227,7 @@ class SGroup(Surface): "resolution": (0, 0), } - def __init__(self, *parametric_surfaces, **kwargs): + def __init__(self, *parametric_surfaces: Surface, **kwargs): super().__init__(uv_func=None, **kwargs) self.add(*parametric_surfaces) @@ -220,7 +247,13 @@ class TexturedSurface(Surface): ] } - def __init__(self, uv_surface, image_file, dark_image_file=None, **kwargs): + def __init__( + self, + uv_surface: Surface, + image_file: str, + dark_image_file: str | None = None, + **kwargs + ): if not isinstance(uv_surface, Surface): raise Exception("uv_surface must be of type Surface") # Set texture information @@ -236,10 +269,10 @@ class TexturedSurface(Surface): self.uv_surface = uv_surface self.uv_func = uv_surface.uv_func - self.u_range = uv_surface.u_range - self.v_range = uv_surface.v_range - self.resolution = uv_surface.resolution - self.gloss = self.uv_surface.gloss + self.u_range: tuple[float, float] = uv_surface.u_range + self.v_range: tuple[float, float] = uv_surface.v_range + self.resolution: tuple[float, float] = uv_surface.resolution + self.gloss: float = self.uv_surface.gloss super().__init__(**kwargs) def init_data(self): @@ -263,12 +296,18 @@ class TexturedSurface(Surface): def init_colors(self): self.data["opacity"] = np.array([self.uv_surface.data["rgbas"][:, 3]]) - def set_opacity(self, opacity, recurse=True): + def set_opacity(self, opacity: float, recurse: bool = True): for mob in self.get_family(recurse): mob.data["opacity"] = np.array([[o] for o in listify(opacity)]) return self - def pointwise_become_partial(self, tsmobject, a, b, axis=1): + def pointwise_become_partial( + self, + tsmobject: "TexturedSurface", + a: float, + b: float, + axis: int = 1 + ): super().pointwise_become_partial(tsmobject, a, b, axis) im_coords = self.data["im_coords"] im_coords[:] = tsmobject.data["im_coords"] @@ -280,7 +319,7 @@ class TexturedSurface(Surface): ) return self - def fill_in_shader_color_info(self, shader_data): + def fill_in_shader_color_info(self, shader_data: np.ndarray) -> np.ndarray: self.read_data_to_shader(shader_data, "opacity", "opacity") self.read_data_to_shader(shader_data, "im_coords", "im_coords") return shader_data diff --git a/manimlib/mobject/types/vectorized_mobject.py b/manimlib/mobject/types/vectorized_mobject.py index 3c7a4326..a1a6c29f 100644 --- a/manimlib/mobject/types/vectorized_mobject.py +++ b/manimlib/mobject/types/vectorized_mobject.py @@ -1,8 +1,13 @@ -import itertools as it -import operator as op -import moderngl +from __future__ import annotations +import operator as op +import itertools as it from functools import reduce, wraps +from typing import Iterable, Sequence, Callable, Union + +import colour +import moderngl +import numpy.typing as npt from manimlib.constants import * from manimlib.mobject.mobject import Mobject @@ -29,6 +34,9 @@ from manimlib.utils.space_ops import z_to_vector from manimlib.shader_wrapper import ShaderWrapper +Color = Union[str, colour.Color, Sequence[float]] + + class VMobject(Mobject): CONFIG = { "fill_color": None, @@ -105,7 +113,12 @@ class VMobject(Mobject): self.set_flat_stroke(self.flat_stroke) return self - def set_rgba_array(self, rgba_array, name=None, recurse=False): + def set_rgba_array( + self, + rgba_array: npt.ArrayLike, + name: str = None, + recurse: bool = False + ): if name is None: names = ["fill_rgba", "stroke_rgba"] else: @@ -115,11 +128,23 @@ class VMobject(Mobject): super().set_rgba_array(rgba_array, name, recurse) return self - def set_fill(self, color=None, opacity=None, recurse=True): + def set_fill( + self, + color: Color | None = None, + opacity: float | None = None, + recurse: bool = True + ): self.set_rgba_array_by_color(color, opacity, 'fill_rgba', recurse) return self - def set_stroke(self, color=None, width=None, opacity=None, background=None, recurse=True): + def set_stroke( + self, + color: Color | None = None, + width: float | npt.ArrayLike | None = None, + opacity: float | None = None, + background: bool | None = None, + recurse: bool = True + ): self.set_rgba_array_by_color(color, opacity, 'stroke_rgba', recurse) if width is not None: @@ -135,29 +160,36 @@ class VMobject(Mobject): mob.draw_stroke_behind_fill = background return self - def set_backstroke(self, color=BLACK, width=3, background=True): + def set_backstroke( + self, + color: Color = BLACK, + width: float | npt.ArrayLike = 3, + background: bool = True + ): self.set_stroke(color, width, background=background) return self - def align_stroke_width_data_to_points(self, recurse=True): + def align_stroke_width_data_to_points(self, recurse: bool = True) -> None: for mob in self.get_family(recurse): mob.data["stroke_width"] = resize_with_interpolation( mob.data["stroke_width"], len(mob.get_points()) ) - def set_style(self, - fill_color=None, - fill_opacity=None, - fill_rgba=None, - stroke_color=None, - stroke_opacity=None, - stroke_rgba=None, - stroke_width=None, - stroke_background=True, - reflectiveness=None, - gloss=None, - shadow=None, - recurse=True): + def set_style( + self, + fill_color: Color | None = None, + fill_opacity: float | None = None, + fill_rgba: npt.ArrayLike | None = None, + stroke_color: Color | None = None, + stroke_opacity: float | None = None, + stroke_rgba: npt.ArrayLike | None = None, + stroke_width: float | npt.ArrayLike | None = None, + stroke_background: bool = True, + reflectiveness: float | None = None, + gloss: float | None = None, + shadow: float | None = None, + recurse: bool = True + ): if fill_rgba is not None: self.data['fill_rgba'] = resize_with_interpolation(fill_rgba, len(fill_rgba)) else: @@ -201,7 +233,7 @@ class VMobject(Mobject): "shadow": self.get_shadow(), } - def match_style(self, vmobject, recurse=True): + def match_style(self, vmobject: "VMobject", recurse: bool = True): self.set_style(**vmobject.get_style(), recurse=False) if recurse: # Does its best to match up submobject lists, and @@ -215,17 +247,17 @@ class VMobject(Mobject): sm1.match_style(sm2) return self - def set_color(self, color, recurse=True): + def set_color(self, color: Color, recurse: bool = True): self.set_fill(color, recurse=recurse) self.set_stroke(color, recurse=recurse) return self - def set_opacity(self, opacity, recurse=True): + def set_opacity(self, opacity: float, recurse: bool = True): self.set_fill(opacity=opacity, recurse=recurse) self.set_stroke(opacity=opacity, recurse=recurse) return self - def fade(self, darkness=0.5, recurse=True): + def fade(self, darkness: float = 0.5, recurse: bool = True): mobs = self.get_family() if recurse else [self] for mob in mobs: factor = 1.0 - darkness @@ -239,78 +271,83 @@ class VMobject(Mobject): ) return self - def get_fill_colors(self): + def get_fill_colors(self) -> list[str]: return [ rgb_to_hex(rgba[:3]) for rgba in self.data['fill_rgba'] ] - def get_fill_opacities(self): + def get_fill_opacities(self) -> np.ndarray: return self.data['fill_rgba'][:, 3] - def get_stroke_colors(self): + def get_stroke_colors(self) -> list[str]: return [ rgb_to_hex(rgba[:3]) for rgba in self.data['stroke_rgba'] ] - def get_stroke_opacities(self): + def get_stroke_opacities(self) -> np.ndarray: return self.data['stroke_rgba'][:, 3] - def get_stroke_widths(self): + def get_stroke_widths(self) -> np.ndarray: return self.data['stroke_width'][:, 0] # TODO, it's weird for these to return the first of various lists # rather than the full information - def get_fill_color(self): + def get_fill_color(self) -> str: """ If there are multiple colors (for gradient) this returns the first one """ return self.get_fill_colors()[0] - def get_fill_opacity(self): + def get_fill_opacity(self) -> float: """ If there are multiple opacities, this returns the first """ return self.get_fill_opacities()[0] - def get_stroke_color(self): + def get_stroke_color(self) -> str: return self.get_stroke_colors()[0] - def get_stroke_width(self): + def get_stroke_width(self) -> float | np.ndarray: return self.get_stroke_widths()[0] - def get_stroke_opacity(self): + def get_stroke_opacity(self) -> float: return self.get_stroke_opacities()[0] - def get_color(self): + def get_color(self) -> str: if self.has_fill(): return self.get_fill_color() return self.get_stroke_color() - def has_stroke(self): + def has_stroke(self) -> bool: return self.get_stroke_widths().any() and self.get_stroke_opacities().any() - def has_fill(self): + def has_fill(self) -> bool: return any(self.get_fill_opacities()) - def get_opacity(self): + def get_opacity(self) -> float: if self.has_fill(): return self.get_fill_opacity() return self.get_stroke_opacity() - def set_flat_stroke(self, flat_stroke=True, recurse=True): + def set_flat_stroke(self, flat_stroke: bool = True, recurse: bool = True): for mob in self.get_family(recurse): mob.flat_stroke = flat_stroke return self - def get_flat_stroke(self): + def get_flat_stroke(self) -> bool: return self.flat_stroke # Points - def set_anchors_and_handles(self, anchors1, handles, anchors2): + def set_anchors_and_handles( + self, + anchors1: np.ndarray, + handles: np.ndarray, + anchors2: np.ndarray + ): assert(len(anchors1) == len(handles) == len(anchors2)) nppc = self.n_points_per_curve new_points = np.zeros((nppc * len(anchors1), self.dim)) @@ -320,16 +357,27 @@ class VMobject(Mobject): self.set_points(new_points) return self - def start_new_path(self, point): + def start_new_path(self, point: np.ndarray): assert(self.get_num_points() % self.n_points_per_curve == 0) self.append_points([point]) return self - def add_cubic_bezier_curve(self, anchor1, handle1, handle2, anchor2): + def add_cubic_bezier_curve( + self, + anchor1: npt.ArrayLike, + handle1: npt.ArrayLike, + handle2: npt.ArrayLike, + anchor2: npt.ArrayLike + ): new_points = get_quadratic_approximation_of_cubic(anchor1, handle1, handle2, anchor2) self.append_points(new_points) - def add_cubic_bezier_curve_to(self, handle1, handle2, anchor): + def add_cubic_bezier_curve_to( + self, + handle1: npt.ArrayLike, + handle2: npt.ArrayLike, + anchor: npt.ArrayLike + ): """ Add cubic bezier curve to the path. """ @@ -342,14 +390,14 @@ class VMobject(Mobject): else: self.append_points(quadratic_approx) - def add_quadratic_bezier_curve_to(self, handle, anchor): + def add_quadratic_bezier_curve_to(self, handle: np.ndarray, anchor: np.ndarray): self.throw_error_if_no_points() if self.has_new_path_started(): self.append_points([handle, anchor]) else: self.append_points([self.get_last_point(), handle, anchor]) - def add_line_to(self, point): + def add_line_to(self, point: np.ndarray): end = self.get_points()[-1] alphas = np.linspace(0, 1, self.n_points_per_curve) if self.long_lines: @@ -371,7 +419,7 @@ class VMobject(Mobject): self.append_points(points) return self - def add_smooth_curve_to(self, point): + def add_smooth_curve_to(self, point: np.ndarray): if self.has_new_path_started(): self.add_line_to(point) else: @@ -380,7 +428,7 @@ class VMobject(Mobject): self.add_quadratic_bezier_curve_to(new_handle, point) return self - def add_smooth_cubic_curve_to(self, handle, point): + def add_smooth_cubic_curve_to(self, handle: np.ndarray, point: np.ndarray): self.throw_error_if_no_points() if self.get_num_points() == 1: new_handle = self.get_points()[-1] @@ -388,13 +436,13 @@ class VMobject(Mobject): new_handle = self.get_reflection_of_last_handle() self.add_cubic_bezier_curve_to(new_handle, handle, point) - def has_new_path_started(self): + def has_new_path_started(self) -> bool: return self.get_num_points() % self.n_points_per_curve == 1 - def get_last_point(self): + def get_last_point(self) -> np.ndarray: return self.get_points()[-1] - def get_reflection_of_last_handle(self): + def get_reflection_of_last_handle(self) -> np.ndarray: points = self.get_points() return 2 * points[-1] - points[-2] @@ -402,12 +450,16 @@ class VMobject(Mobject): if not self.is_closed(): self.add_line_to(self.get_subpaths()[-1][0]) - def is_closed(self): + def is_closed(self) -> bool: return self.consider_points_equals( self.get_points()[0], self.get_points()[-1] ) - def subdivide_sharp_curves(self, angle_threshold=30 * DEGREES, recurse=True): + def subdivide_sharp_curves( + self, + angle_threshold: float = 30 * DEGREES, + recurse: bool = True + ): vmobs = [vm for vm in self.get_family(recurse) if vm.has_points()] for vmob in vmobs: new_points = [] @@ -425,12 +477,12 @@ class VMobject(Mobject): vmob.set_points(np.vstack(new_points)) return self - def add_points_as_corners(self, points): + def add_points_as_corners(self, points: Iterable[np.ndarray]): for point in points: self.add_line_to(point) return points - def set_points_as_corners(self, points): + def set_points_as_corners(self, points: Iterable[np.ndarray]): nppc = self.n_points_per_curve points = np.array(points) self.set_anchors_and_handles(*[ @@ -439,7 +491,11 @@ class VMobject(Mobject): ]) return self - def set_points_smoothly(self, points, true_smooth=False): + def set_points_smoothly( + self, + points: Iterable[np.ndarray], + true_smooth: bool = False + ): self.set_points_as_corners(points) if true_smooth: self.make_smooth() @@ -447,7 +503,7 @@ class VMobject(Mobject): self.make_approximately_smooth() return self - def change_anchor_mode(self, mode): + def change_anchor_mode(self, mode: str): assert(mode in ("jagged", "approx_smooth", "true_smooth")) nppc = self.n_points_per_curve for submob in self.family_members_with_points(): @@ -492,12 +548,12 @@ class VMobject(Mobject): self.change_anchor_mode("jagged") return self - def add_subpath(self, points): + def add_subpath(self, points: Iterable[np.ndarray]): assert(len(points) % self.n_points_per_curve == 0) self.append_points(points) return self - def append_vectorized_mobject(self, vectorized_mobject): + def append_vectorized_mobject(self, vectorized_mobject: "VMobject"): new_points = list(vectorized_mobject.get_points()) if self.has_new_path_started(): @@ -508,11 +564,11 @@ class VMobject(Mobject): return self # - def consider_points_equals(self, p0, p1): + def consider_points_equals(self, p0: np.ndarray, p1: np.ndarray) -> bool: return get_norm(p1 - p0) < self.tolerance_for_point_equality # Information about the curve - def get_bezier_tuples_from_points(self, points): + def get_bezier_tuples_from_points(self, points: Sequence[np.ndarray]): nppc = self.n_points_per_curve remainder = len(points) % nppc points = points[:len(points) - remainder] @@ -524,7 +580,10 @@ class VMobject(Mobject): def get_bezier_tuples(self): return self.get_bezier_tuples_from_points(self.get_points()) - def get_subpaths_from_points(self, points): + def get_subpaths_from_points( + self, + points: Sequence[np.ndarray] + ) -> list[Sequence[np.ndarray]]: nppc = self.n_points_per_curve diffs = points[nppc - 1:-1:nppc] - points[nppc::nppc] splits = (diffs * diffs).sum(1) > self.tolerance_for_point_equality @@ -541,28 +600,28 @@ class VMobject(Mobject): if (i2 - i1) >= nppc ] - def get_subpaths(self): + def get_subpaths(self) -> list[Sequence[np.ndarray]]: return self.get_subpaths_from_points(self.get_points()) - def get_nth_curve_points(self, n): + def get_nth_curve_points(self, n: int) -> np.ndarray: assert(n < self.get_num_curves()) nppc = self.n_points_per_curve return self.get_points()[nppc * n:nppc * (n + 1)] - def get_nth_curve_function(self, n): + def get_nth_curve_function(self, n: int) -> Callable[[float], np.ndarray]: return bezier(self.get_nth_curve_points(n)) - def get_num_curves(self): + def get_num_curves(self) -> int: return self.get_num_points() // self.n_points_per_curve - def quick_point_from_proportion(self, alpha): + def quick_point_from_proportion(self, alpha: float) -> np.ndarray: # Assumes all curves have the same length, so is inaccurate num_curves = self.get_num_curves() n, residue = integer_interpolate(0, num_curves, alpha) curve_func = self.get_nth_curve_function(n) return curve_func(residue) - def point_from_proportion(self, alpha): + def point_from_proportion(self, alpha: float) -> np.ndarray: if alpha <= 0: return self.get_start() elif alpha >= 1: @@ -584,7 +643,7 @@ class VMobject(Mobject): residue = inverse_interpolate(partials[i - 1] / full, partials[i] / full, alpha) return self.get_nth_curve_function(i - 1)(residue) - def get_anchors_and_handles(self): + def get_anchors_and_handles(self) -> list[np.ndarray]: """ returns anchors1, handles, anchors2, where (anchors1[i], handles[i], anchors2[i]) @@ -598,14 +657,14 @@ class VMobject(Mobject): for i in range(nppc) ] - def get_start_anchors(self): + def get_start_anchors(self) -> np.ndarray: return self.get_points()[0::self.n_points_per_curve] - def get_end_anchors(self): + def get_end_anchors(self) -> np.ndarray: nppc = self.n_points_per_curve return self.get_points()[nppc - 1::nppc] - def get_anchors(self): + def get_anchors(self) -> np.ndarray: points = self.get_points() if len(points) == 1: return points @@ -614,7 +673,7 @@ class VMobject(Mobject): self.get_end_anchors(), )))) - def get_points_without_null_curves(self, atol=1e-9): + def get_points_without_null_curves(self, atol: float=1e-9) -> np.ndarray: nppc = self.n_points_per_curve points = self.get_points() distinct_curves = reduce(op.or_, [ @@ -623,7 +682,7 @@ class VMobject(Mobject): ]) return points[distinct_curves.repeat(nppc)] - def get_arc_length(self, n_sample_points=None): + def get_arc_length(self, n_sample_points: int | None = None) -> float: if n_sample_points is None: n_sample_points = 4 * self.get_num_curves() + 1 points = np.array([ @@ -634,7 +693,7 @@ class VMobject(Mobject): norms = np.array([get_norm(d) for d in diffs]) return norms.sum() - def get_area_vector(self): + def get_area_vector(self) -> np.ndarray: # Returns a vector whose length is the area bound by # the polygon formed by the anchor points, pointing # in a direction perpendicular to the polygon according @@ -654,7 +713,7 @@ class VMobject(Mobject): sum((p0[:, 0] + p1[:, 0]) * (p1[:, 1] - p0[:, 1])), # Add up (x1 + x2)*(y2 - y1) ]) - def get_unit_normal(self, recompute=False): + def get_unit_normal(self, recompute: bool = False) -> np.ndarray: if not recompute: return self.data["unit_normal"][0] @@ -680,7 +739,7 @@ class VMobject(Mobject): return self # Alignment - def align_points(self, vmobject): + def align_points(self, vmobject: "VMobject"): if self.get_num_points() == len(vmobject.get_points()): return @@ -723,7 +782,7 @@ class VMobject(Mobject): vmobject.set_points(np.vstack(new_subpaths2)) return self - def insert_n_curves(self, n, recurse=True): + def insert_n_curves(self, n: int, recurse: bool = True): for mob in self.get_family(recurse): if mob.get_num_curves() > 0: new_points = mob.insert_n_curves_to_point_list(n, mob.get_points()) @@ -733,7 +792,7 @@ class VMobject(Mobject): mob.set_points(new_points) return self - def insert_n_curves_to_point_list(self, n, points): + def insert_n_curves_to_point_list(self, n: int, points: np.ndarray): nppc = self.n_points_per_curve if len(points) == 1: return np.repeat(points, nppc * n, 0) @@ -766,7 +825,13 @@ class VMobject(Mobject): new_points += partial_quadratic_bezier_points(group, a1, a2) return np.vstack(new_points) - def interpolate(self, mobject1, mobject2, alpha, *args, **kwargs): + def interpolate( + self, + mobject1: "VMobject", + mobject2: "VMobject", + alpha: float, + *args, **kwargs + ): super().interpolate(mobject1, mobject2, alpha, *args, **kwargs) if self.has_fill(): tri1 = mobject1.get_triangulation() @@ -775,7 +840,7 @@ class VMobject(Mobject): self.refresh_triangulation() return self - def pointwise_become_partial(self, vmobject, a, b): + def pointwise_become_partial(self, vmobject: "VMobject", a: float, b: float): assert(isinstance(vmobject, VMobject)) if a <= 0 and b >= 1: self.become(vmobject) @@ -817,7 +882,7 @@ class VMobject(Mobject): self.set_points(new_points) return self - def get_subcurve(self, a, b): + def get_subcurve(self, a: float, b: float) -> "VMobject": vmob = self.copy() vmob.pointwise_become_partial(self, a, b) return vmob @@ -829,7 +894,7 @@ class VMobject(Mobject): mob.needs_new_triangulation = True return self - def get_triangulation(self, normal_vector=None): + def get_triangulation(self, normal_vector: np.ndarray | None = None): # Figure out how to triangulate the interior to know # how to send the points as to the vertex shader. # First triangles come directly from the points @@ -898,25 +963,30 @@ class VMobject(Mobject): return wrapper @triggers_refreshed_triangulation - def set_points(self, points): + def set_points(self, points: npt.ArrayLike): super().set_points(points) return self @triggers_refreshed_triangulation - def set_data(self, data): + def set_data(self, data: dict): super().set_data(data) return self # TODO, how to be smart about tangents here? @triggers_refreshed_triangulation - def apply_function(self, function, make_smooth=False, **kwargs): + def apply_function( + self, + function: Callable[[np.ndarray], np.ndarray], + make_smooth: bool = False, + **kwargs + ): super().apply_function(function, **kwargs) if self.make_smooth_after_applying_functions or make_smooth: self.make_approximately_smooth() return self - def flip(self, *args, **kwargs): - super().flip(*args, **kwargs) + def flip(self, axis: np.ndarray = UP, **kwargs): + super().flip(axis, **kwargs) self.refresh_unit_normal() self.refresh_triangulation() return self @@ -942,20 +1012,20 @@ class VMobject(Mobject): wrapper.refresh_id() return self - def get_fill_shader_wrapper(self): + def get_fill_shader_wrapper(self) -> ShaderWrapper: self.fill_shader_wrapper.vert_data = self.get_fill_shader_data() self.fill_shader_wrapper.vert_indices = self.get_fill_shader_vert_indices() self.fill_shader_wrapper.uniforms = self.get_shader_uniforms() self.fill_shader_wrapper.depth_test = self.depth_test return self.fill_shader_wrapper - def get_stroke_shader_wrapper(self): + def get_stroke_shader_wrapper(self) -> ShaderWrapper: self.stroke_shader_wrapper.vert_data = self.get_stroke_shader_data() self.stroke_shader_wrapper.uniforms = self.get_stroke_uniforms() self.stroke_shader_wrapper.depth_test = self.depth_test return self.stroke_shader_wrapper - def get_shader_wrapper_list(self): + def get_shader_wrapper_list(self) -> list[ShaderWrapper]: # Build up data lists fill_shader_wrappers = [] stroke_shader_wrappers = [] @@ -984,13 +1054,13 @@ class VMobject(Mobject): result.append(wrapper) return result - def get_stroke_uniforms(self): + def get_stroke_uniforms(self) -> dict[str, float]: result = dict(super().get_shader_uniforms()) result["joint_type"] = JOINT_TYPE_MAP[self.joint_type] result["flat_stroke"] = float(self.flat_stroke) return result - def get_stroke_shader_data(self): + def get_stroke_shader_data(self) -> np.ndarray: points = self.get_points() if len(self.stroke_data) != len(points): self.stroke_data = resize_array(self.stroke_data, len(points)) @@ -1009,7 +1079,7 @@ class VMobject(Mobject): return self.stroke_data - def get_fill_shader_data(self): + def get_fill_shader_data(self) -> np.ndarray: points = self.get_points() if len(self.fill_data) != len(points): self.fill_data = resize_array(self.fill_data, len(points)) @@ -1025,18 +1095,18 @@ class VMobject(Mobject): self.get_fill_shader_data() self.get_stroke_shader_data() - def get_fill_shader_vert_indices(self): + def get_fill_shader_vert_indices(self) -> np.ndarray: return self.get_triangulation() class VGroup(VMobject): - def __init__(self, *vmobjects, **kwargs): + def __init__(self, *vmobjects: VMobject, **kwargs): if not all([isinstance(m, VMobject) for m in vmobjects]): raise Exception("All submobjects must be of type VMobject") super().__init__(**kwargs) self.add(*vmobjects) - def __add__(self: 'VGroup', other: 'VMobject' or 'VGroup'): + def __add__(self, other: VMobject | "VGroup"): assert(isinstance(other, VMobject)) return self.add(other) @@ -1050,14 +1120,14 @@ class VectorizedPoint(Point, VMobject): "artificial_height": 0.01, } - def __init__(self, location=ORIGIN, **kwargs): + def __init__(self, location: np.ndarray = ORIGIN, **kwargs): Point.__init__(self, **kwargs) VMobject.__init__(self, **kwargs) self.set_points(np.array([location])) class CurvesAsSubmobjects(VGroup): - def __init__(self, vmobject, **kwargs): + def __init__(self, vmobject: VMobject, **kwargs): super().__init__(**kwargs) for tup in vmobject.get_bezier_tuples(): part = VMobject() @@ -1073,7 +1143,7 @@ class DashedVMobject(VMobject): "color": WHITE } - def __init__(self, vmobject, **kwargs): + def __init__(self, vmobject: VMobject, **kwargs): super().__init__(**kwargs) num_dashes = self.num_dashes ps_ratio = self.positive_space_ratio