chore: add type hints to manimlib.mobject.types

This commit is contained in:
TonyCrane 2022-02-13 18:56:50 +08:00
parent 7f8216bb09
commit 19187ead06
No known key found for this signature in database
GPG key ID: 2313A5058A9C637C
5 changed files with 306 additions and 164 deletions

View file

@ -1,4 +1,7 @@
from __future__ import annotations
import numpy as np
import numpy.typing as npt
import moderngl
from manimlib.constants import GREY_C
@ -29,27 +32,31 @@ class DotCloud(PMobject):
],
}
def __init__(self, points=None, **kwargs):
def __init__(self, points: npt.ArrayLike = None, **kwargs):
super().__init__(**kwargs)
if points is not None:
self.set_points(points)
def init_data(self):
def init_data(self) -> None:
super().init_data()
self.data["radii"] = np.zeros((1, 1))
self.set_radius(self.radius)
def init_uniforms(self):
def init_uniforms(self) -> None:
super().init_uniforms()
self.uniforms["glow_factor"] = self.glow_factor
def to_grid(self, n_rows, n_cols, n_layers=1,
buff_ratio=None,
h_buff_ratio=1.0,
v_buff_ratio=1.0,
d_buff_ratio=1.0,
height=DEFAULT_GRID_HEIGHT,
):
def to_grid(
self,
n_rows: int,
n_cols: int,
n_layers: int = 1,
buff_ratio: float | None = None,
h_buff_ratio: float = 1.0,
v_buff_ratio: float = 1.0,
d_buff_ratio: float = 1.0,
height: float = DEFAULT_GRID_HEIGHT,
):
n_points = n_rows * n_cols * n_layers
points = np.repeat(range(n_points), 3, axis=0).reshape((n_points, 3))
points[:, 0] = points[:, 0] % n_cols
@ -74,50 +81,55 @@ class DotCloud(PMobject):
self.center()
return self
def set_radii(self, radii):
def set_radii(self, radii: npt.ArrayLike):
n_points = len(self.get_points())
radii = np.array(radii).reshape((len(radii), 1))
self.data["radii"] = resize_preserving_order(radii, n_points)
self.refresh_bounding_box()
return self
def get_radii(self):
def get_radii(self) -> np.ndarray:
return self.data["radii"]
def set_radius(self, radius):
def set_radius(self, radius: float):
self.data["radii"][:] = radius
self.refresh_bounding_box()
return self
def get_radius(self):
def get_radius(self) -> float:
return self.get_radii().max()
def set_glow_factor(self, glow_factor):
def set_glow_factor(self, glow_factor: float) -> None:
self.uniforms["glow_factor"] = glow_factor
def get_glow_factor(self):
def get_glow_factor(self) -> float:
return self.uniforms["glow_factor"]
def compute_bounding_box(self):
def compute_bounding_box(self) -> np.ndarray:
bb = super().compute_bounding_box()
radius = self.get_radius()
bb[0] += np.full((3,), -radius)
bb[2] += np.full((3,), radius)
return bb
def scale(self, scale_factor, scale_radii=True, **kwargs):
def scale(
self,
scale_factor: float | npt.ArrayLike,
scale_radii: bool = True,
**kwargs
):
super().scale(scale_factor, **kwargs)
if scale_radii:
self.set_radii(scale_factor * self.get_radii())
return self
def make_3d(self, reflectiveness=0.5, shadow=0.2):
def make_3d(self, reflectiveness: float = 0.5, shadow: float = 0.2):
self.set_reflectiveness(reflectiveness)
self.set_shadow(shadow)
self.apply_depth_test()
return self
def get_shader_data(self):
def get_shader_data(self) -> np.ndarray:
shader_data = super().get_shader_data()
self.read_data_to_shader(shader_data, "radius", "radii")
self.read_data_to_shader(shader_data, "color", "rgbas")
@ -125,7 +137,7 @@ class DotCloud(PMobject):
class TrueDot(DotCloud):
def __init__(self, center=ORIGIN, **kwargs):
def __init__(self, center: np.ndarray = ORIGIN, **kwargs):
super().__init__(points=[center], **kwargs)

View file

@ -1,5 +1,6 @@
import numpy as np
from __future__ import annotations
import numpy as np
from PIL import Image
from manimlib.constants import *
@ -21,33 +22,33 @@ class ImageMobject(Mobject):
]
}
def __init__(self, filename, **kwargs):
def __init__(self, filename: str, **kwargs):
self.set_image_path(get_full_raster_image_path(filename))
super().__init__(**kwargs)
def set_image_path(self, path):
def set_image_path(self, path: str) -> None:
self.path = path
self.image = Image.open(path)
self.texture_paths = {"Texture": path}
def init_data(self):
def init_data(self) -> None:
self.data = {
"points": np.array([UL, DL, UR, DR]),
"im_coords": np.array([(0, 0), (0, 1), (1, 0), (1, 1)]),
"opacity": np.array([[self.opacity]], dtype=np.float32),
}
def init_points(self):
def init_points(self) -> None:
size = self.image.size
self.set_width(2 * size[0] / size[1], stretch=True)
self.set_height(self.height)
def set_opacity(self, opacity, recurse=True):
def set_opacity(self, opacity: float, recurse: bool = True):
for mob in self.get_family(recurse):
mob.data["opacity"] = np.array([[o] for o in listify(opacity)])
return self
def point_to_rgb(self, point):
def point_to_rgb(self, point: np.ndarray) -> np.ndarray:
x0, y0 = self.get_corner(UL)[:2]
x1, y1 = self.get_corner(DR)[:2]
x_alpha = inverse_interpolate(x0, x1, point[0])
@ -63,7 +64,7 @@ class ImageMobject(Mobject):
))
return np.array(rgb) / 255
def get_shader_data(self):
def get_shader_data(self) -> np.ndarray:
shader_data = super().get_shader_data()
self.read_data_to_shader(shader_data, "im_coords", "im_coords")
self.read_data_to_shader(shader_data, "opacity", "opacity")

View file

@ -1,3 +1,10 @@
from __future__ import annotations
from typing import Callable, Sequence, Union
import colour
import numpy.typing as npt
from manimlib.constants import *
from manimlib.mobject.mobject import Mobject
from manimlib.utils.color import color_gradient
@ -6,26 +13,39 @@ from manimlib.utils.iterables import resize_with_interpolation
from manimlib.utils.iterables import resize_array
Color = Union[str, colour.Color, Sequence[float]]
class PMobject(Mobject):
CONFIG = {
"opacity": 1.0,
}
def resize_points(self, size, resize_func=resize_array):
def resize_points(
self,
size: int,
resize_func: Callable[[np.ndarray, int], np.ndarray] = resize_array
):
# TODO
for key in self.data:
if key == "bounding_box":
continue
if len(self.data[key]) != size:
self.data[key] = resize_array(self.data[key], size)
self.data[key] = resize_func(self.data[key], size)
return self
def set_points(self, points):
def set_points(self, points: npt.ArrayLike):
super().set_points(points)
self.resize_points(len(points))
return self
def add_points(self, points, rgbas=None, color=None, opacity=None):
def add_points(
self,
points: npt.ArrayLike,
rgbas: np.ndarray | None = None,
color: Color | None = None,
opacity: float | None = None
):
"""
points must be a Nx3 numpy array, as must rgbas if it is not None
"""
@ -44,20 +64,20 @@ class PMobject(Mobject):
self.data["rgbas"][-len(new_rgbas):] = new_rgbas
return self
def set_color_by_gradient(self, *colors):
def set_color_by_gradient(self, *colors: Color):
self.data["rgbas"] = np.array(list(map(
color_to_rgba,
color_gradient(colors, self.get_num_points())
)))
return self
def match_colors(self, pmobject):
def match_colors(self, pmobject: "PMobject"):
self.data["rgbas"][:] = resize_with_interpolation(
pmobject.data["rgbas"], self.get_num_points()
)
return self
def filter_out(self, condition):
def filter_out(self, condition: Callable[[np.ndarray], bool]):
for mob in self.family_members_with_points():
to_keep = ~np.apply_along_axis(condition, 1, mob.get_points())
for key in mob.data:
@ -66,7 +86,7 @@ class PMobject(Mobject):
mob.data[key] = mob.data[key][to_keep]
return self
def sort_points(self, function=lambda p: p[0]):
def sort_points(self, function: Callable[[np.ndarray]] = lambda p: p[0]):
"""
function is any map from R^3 to R
"""
@ -86,11 +106,11 @@ class PMobject(Mobject):
])
return self
def point_from_proportion(self, alpha):
def point_from_proportion(self, alpha: float) -> np.ndarray:
index = alpha * (self.get_num_points() - 1)
return self.get_points()[int(index)]
def pointwise_become_partial(self, pmobject, a, b):
def pointwise_become_partial(self, pmobject: "PMobject", a: float, b: float):
lower_index = int(a * pmobject.get_num_points())
upper_index = int(b * pmobject.get_num_points())
for key in self.data:
@ -101,7 +121,7 @@ class PMobject(Mobject):
class PGroup(PMobject):
def __init__(self, *pmobs, **kwargs):
def __init__(self, *pmobs: PMobject, **kwargs):
if not all([isinstance(m, PMobject) for m in pmobs]):
raise Exception("All submobjects must be of type PMobject")
super().__init__(*pmobs, **kwargs)
@ -112,6 +132,6 @@ class Point(PMobject):
"color": BLACK,
}
def __init__(self, location=ORIGIN, **kwargs):
def __init__(self, location: np.ndarray = ORIGIN, **kwargs):
super().__init__(**kwargs)
self.add_points([location])

View file

@ -1,7 +1,13 @@
import numpy as np
from __future__ import annotations
from typing import Iterable, Callable
import moderngl
import numpy as np
import numpy.typing as npt
from manimlib.constants import *
from manimlib.camera.camera import Camera
from manimlib.mobject.mobject import Mobject
from manimlib.utils.bezier import integer_interpolate
from manimlib.utils.bezier import interpolate
@ -42,7 +48,7 @@ class Surface(Mobject):
super().__init__(**kwargs)
self.compute_triangle_indices()
def uv_func(self, u, v):
def uv_func(self, u: float, v: float) -> tuple[float, float, float]:
# To be implemented in subclasses
return (u, v, 0.0)
@ -85,15 +91,17 @@ class Surface(Mobject):
indices[5::6] = index_grid[+1:, +1:].flatten() # Bottom right
self.triangle_indices = indices
def get_triangle_indices(self):
def get_triangle_indices(self) -> np.ndarray:
return self.triangle_indices
def get_surface_points_and_nudged_points(self):
def get_surface_points_and_nudged_points(
self
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
points = self.get_points()
k = len(points) // 3
return points[:k], points[k:2 * k], points[2 * k:]
def get_unit_normals(self):
def get_unit_normals(self) -> np.ndarray:
s_points, du_points, dv_points = self.get_surface_points_and_nudged_points()
normals = np.cross(
(du_points - s_points) / self.epsilon,
@ -101,7 +109,13 @@ class Surface(Mobject):
)
return normalize_along_axis(normals, 1)
def pointwise_become_partial(self, smobject, a, b, axis=None):
def pointwise_become_partial(
self,
smobject: "Surface",
a: float,
b: float,
axis: np.ndarray | None = None
):
assert(isinstance(smobject, Surface))
if axis is None:
axis = self.prefered_creation_axis
@ -116,7 +130,14 @@ class Surface(Mobject):
]))
return self
def get_partial_points_array(self, points, a, b, resolution, axis):
def get_partial_points_array(
self,
points: np.ndarray,
a: float,
b: float,
resolution: npt.ArrayLike,
axis: int
) -> np.ndarray:
if len(points) == 0:
return points
nu, nv = resolution[:2]
@ -149,7 +170,7 @@ class Surface(Mobject):
).reshape(shape)
return points.reshape((nu * nv, *resolution[2:]))
def sort_faces_back_to_front(self, vect=OUT):
def sort_faces_back_to_front(self, vect: np.ndarray = OUT):
tri_is = self.triangle_indices
indices = list(range(len(tri_is) // 3))
points = self.get_points()
@ -162,13 +183,13 @@ class Surface(Mobject):
tri_is[k::3] = tri_is[k::3][indices]
return self
def always_sort_to_camera(self, camera):
def always_sort_to_camera(self, camera: Camera):
self.add_updater(lambda m: m.sort_faces_back_to_front(
camera.get_location() - self.get_center()
))
# For shaders
def get_shader_data(self):
def get_shader_data(self) -> np.ndarray:
s_points, du_points, dv_points = self.get_surface_points_and_nudged_points()
shader_data = self.get_resized_shader_data_array(len(s_points))
if "points" not in self.locked_data_keys:
@ -178,16 +199,22 @@ class Surface(Mobject):
self.fill_in_shader_color_info(shader_data)
return shader_data
def fill_in_shader_color_info(self, shader_data):
def fill_in_shader_color_info(self, shader_data: np.ndarray) -> np.ndarray:
self.read_data_to_shader(shader_data, "color", "rgbas")
return shader_data
def get_shader_vert_indices(self):
def get_shader_vert_indices(self) -> np.ndarray:
return self.get_triangle_indices()
class ParametricSurface(Surface):
def __init__(self, uv_func, u_range=(0, 1), v_range=(0, 1), **kwargs):
def __init__(
self,
uv_func: Callable[[float, float], Iterable[float]],
u_range: tuple[float, float] = (0, 1),
v_range: tuple[float, float] = (0, 1),
**kwargs
):
self.passed_uv_func = uv_func
super().__init__(u_range=u_range, v_range=v_range, **kwargs)
@ -200,7 +227,7 @@ class SGroup(Surface):
"resolution": (0, 0),
}
def __init__(self, *parametric_surfaces, **kwargs):
def __init__(self, *parametric_surfaces: Surface, **kwargs):
super().__init__(uv_func=None, **kwargs)
self.add(*parametric_surfaces)
@ -220,7 +247,13 @@ class TexturedSurface(Surface):
]
}
def __init__(self, uv_surface, image_file, dark_image_file=None, **kwargs):
def __init__(
self,
uv_surface: Surface,
image_file: str,
dark_image_file: str | None = None,
**kwargs
):
if not isinstance(uv_surface, Surface):
raise Exception("uv_surface must be of type Surface")
# Set texture information
@ -236,10 +269,10 @@ class TexturedSurface(Surface):
self.uv_surface = uv_surface
self.uv_func = uv_surface.uv_func
self.u_range = uv_surface.u_range
self.v_range = uv_surface.v_range
self.resolution = uv_surface.resolution
self.gloss = self.uv_surface.gloss
self.u_range: tuple[float, float] = uv_surface.u_range
self.v_range: tuple[float, float] = uv_surface.v_range
self.resolution: tuple[float, float] = uv_surface.resolution
self.gloss: float = self.uv_surface.gloss
super().__init__(**kwargs)
def init_data(self):
@ -263,12 +296,18 @@ class TexturedSurface(Surface):
def init_colors(self):
self.data["opacity"] = np.array([self.uv_surface.data["rgbas"][:, 3]])
def set_opacity(self, opacity, recurse=True):
def set_opacity(self, opacity: float, recurse: bool = True):
for mob in self.get_family(recurse):
mob.data["opacity"] = np.array([[o] for o in listify(opacity)])
return self
def pointwise_become_partial(self, tsmobject, a, b, axis=1):
def pointwise_become_partial(
self,
tsmobject: "TexturedSurface",
a: float,
b: float,
axis: int = 1
):
super().pointwise_become_partial(tsmobject, a, b, axis)
im_coords = self.data["im_coords"]
im_coords[:] = tsmobject.data["im_coords"]
@ -280,7 +319,7 @@ class TexturedSurface(Surface):
)
return self
def fill_in_shader_color_info(self, shader_data):
def fill_in_shader_color_info(self, shader_data: np.ndarray) -> np.ndarray:
self.read_data_to_shader(shader_data, "opacity", "opacity")
self.read_data_to_shader(shader_data, "im_coords", "im_coords")
return shader_data

View file

@ -1,8 +1,13 @@
import itertools as it
import operator as op
import moderngl
from __future__ import annotations
import operator as op
import itertools as it
from functools import reduce, wraps
from typing import Iterable, Sequence, Callable, Union
import colour
import moderngl
import numpy.typing as npt
from manimlib.constants import *
from manimlib.mobject.mobject import Mobject
@ -29,6 +34,9 @@ from manimlib.utils.space_ops import z_to_vector
from manimlib.shader_wrapper import ShaderWrapper
Color = Union[str, colour.Color, Sequence[float]]
class VMobject(Mobject):
CONFIG = {
"fill_color": None,
@ -105,7 +113,12 @@ class VMobject(Mobject):
self.set_flat_stroke(self.flat_stroke)
return self
def set_rgba_array(self, rgba_array, name=None, recurse=False):
def set_rgba_array(
self,
rgba_array: npt.ArrayLike,
name: str = None,
recurse: bool = False
):
if name is None:
names = ["fill_rgba", "stroke_rgba"]
else:
@ -115,11 +128,23 @@ class VMobject(Mobject):
super().set_rgba_array(rgba_array, name, recurse)
return self
def set_fill(self, color=None, opacity=None, recurse=True):
def set_fill(
self,
color: Color | None = None,
opacity: float | None = None,
recurse: bool = True
):
self.set_rgba_array_by_color(color, opacity, 'fill_rgba', recurse)
return self
def set_stroke(self, color=None, width=None, opacity=None, background=None, recurse=True):
def set_stroke(
self,
color: Color | None = None,
width: float | npt.ArrayLike | None = None,
opacity: float | None = None,
background: bool | None = None,
recurse: bool = True
):
self.set_rgba_array_by_color(color, opacity, 'stroke_rgba', recurse)
if width is not None:
@ -135,29 +160,36 @@ class VMobject(Mobject):
mob.draw_stroke_behind_fill = background
return self
def set_backstroke(self, color=BLACK, width=3, background=True):
def set_backstroke(
self,
color: Color = BLACK,
width: float | npt.ArrayLike = 3,
background: bool = True
):
self.set_stroke(color, width, background=background)
return self
def align_stroke_width_data_to_points(self, recurse=True):
def align_stroke_width_data_to_points(self, recurse: bool = True) -> None:
for mob in self.get_family(recurse):
mob.data["stroke_width"] = resize_with_interpolation(
mob.data["stroke_width"], len(mob.get_points())
)
def set_style(self,
fill_color=None,
fill_opacity=None,
fill_rgba=None,
stroke_color=None,
stroke_opacity=None,
stroke_rgba=None,
stroke_width=None,
stroke_background=True,
reflectiveness=None,
gloss=None,
shadow=None,
recurse=True):
def set_style(
self,
fill_color: Color | None = None,
fill_opacity: float | None = None,
fill_rgba: npt.ArrayLike | None = None,
stroke_color: Color | None = None,
stroke_opacity: float | None = None,
stroke_rgba: npt.ArrayLike | None = None,
stroke_width: float | npt.ArrayLike | None = None,
stroke_background: bool = True,
reflectiveness: float | None = None,
gloss: float | None = None,
shadow: float | None = None,
recurse: bool = True
):
if fill_rgba is not None:
self.data['fill_rgba'] = resize_with_interpolation(fill_rgba, len(fill_rgba))
else:
@ -201,7 +233,7 @@ class VMobject(Mobject):
"shadow": self.get_shadow(),
}
def match_style(self, vmobject, recurse=True):
def match_style(self, vmobject: "VMobject", recurse: bool = True):
self.set_style(**vmobject.get_style(), recurse=False)
if recurse:
# Does its best to match up submobject lists, and
@ -215,17 +247,17 @@ class VMobject(Mobject):
sm1.match_style(sm2)
return self
def set_color(self, color, recurse=True):
def set_color(self, color: Color, recurse: bool = True):
self.set_fill(color, recurse=recurse)
self.set_stroke(color, recurse=recurse)
return self
def set_opacity(self, opacity, recurse=True):
def set_opacity(self, opacity: float, recurse: bool = True):
self.set_fill(opacity=opacity, recurse=recurse)
self.set_stroke(opacity=opacity, recurse=recurse)
return self
def fade(self, darkness=0.5, recurse=True):
def fade(self, darkness: float = 0.5, recurse: bool = True):
mobs = self.get_family() if recurse else [self]
for mob in mobs:
factor = 1.0 - darkness
@ -239,78 +271,83 @@ class VMobject(Mobject):
)
return self
def get_fill_colors(self):
def get_fill_colors(self) -> list[str]:
return [
rgb_to_hex(rgba[:3])
for rgba in self.data['fill_rgba']
]
def get_fill_opacities(self):
def get_fill_opacities(self) -> np.ndarray:
return self.data['fill_rgba'][:, 3]
def get_stroke_colors(self):
def get_stroke_colors(self) -> list[str]:
return [
rgb_to_hex(rgba[:3])
for rgba in self.data['stroke_rgba']
]
def get_stroke_opacities(self):
def get_stroke_opacities(self) -> np.ndarray:
return self.data['stroke_rgba'][:, 3]
def get_stroke_widths(self):
def get_stroke_widths(self) -> np.ndarray:
return self.data['stroke_width'][:, 0]
# TODO, it's weird for these to return the first of various lists
# rather than the full information
def get_fill_color(self):
def get_fill_color(self) -> str:
"""
If there are multiple colors (for gradient)
this returns the first one
"""
return self.get_fill_colors()[0]
def get_fill_opacity(self):
def get_fill_opacity(self) -> float:
"""
If there are multiple opacities, this returns the
first
"""
return self.get_fill_opacities()[0]
def get_stroke_color(self):
def get_stroke_color(self) -> str:
return self.get_stroke_colors()[0]
def get_stroke_width(self):
def get_stroke_width(self) -> float | np.ndarray:
return self.get_stroke_widths()[0]
def get_stroke_opacity(self):
def get_stroke_opacity(self) -> float:
return self.get_stroke_opacities()[0]
def get_color(self):
def get_color(self) -> str:
if self.has_fill():
return self.get_fill_color()
return self.get_stroke_color()
def has_stroke(self):
def has_stroke(self) -> bool:
return self.get_stroke_widths().any() and self.get_stroke_opacities().any()
def has_fill(self):
def has_fill(self) -> bool:
return any(self.get_fill_opacities())
def get_opacity(self):
def get_opacity(self) -> float:
if self.has_fill():
return self.get_fill_opacity()
return self.get_stroke_opacity()
def set_flat_stroke(self, flat_stroke=True, recurse=True):
def set_flat_stroke(self, flat_stroke: bool = True, recurse: bool = True):
for mob in self.get_family(recurse):
mob.flat_stroke = flat_stroke
return self
def get_flat_stroke(self):
def get_flat_stroke(self) -> bool:
return self.flat_stroke
# Points
def set_anchors_and_handles(self, anchors1, handles, anchors2):
def set_anchors_and_handles(
self,
anchors1: np.ndarray,
handles: np.ndarray,
anchors2: np.ndarray
):
assert(len(anchors1) == len(handles) == len(anchors2))
nppc = self.n_points_per_curve
new_points = np.zeros((nppc * len(anchors1), self.dim))
@ -320,16 +357,27 @@ class VMobject(Mobject):
self.set_points(new_points)
return self
def start_new_path(self, point):
def start_new_path(self, point: np.ndarray):
assert(self.get_num_points() % self.n_points_per_curve == 0)
self.append_points([point])
return self
def add_cubic_bezier_curve(self, anchor1, handle1, handle2, anchor2):
def add_cubic_bezier_curve(
self,
anchor1: npt.ArrayLike,
handle1: npt.ArrayLike,
handle2: npt.ArrayLike,
anchor2: npt.ArrayLike
):
new_points = get_quadratic_approximation_of_cubic(anchor1, handle1, handle2, anchor2)
self.append_points(new_points)
def add_cubic_bezier_curve_to(self, handle1, handle2, anchor):
def add_cubic_bezier_curve_to(
self,
handle1: npt.ArrayLike,
handle2: npt.ArrayLike,
anchor: npt.ArrayLike
):
"""
Add cubic bezier curve to the path.
"""
@ -342,14 +390,14 @@ class VMobject(Mobject):
else:
self.append_points(quadratic_approx)
def add_quadratic_bezier_curve_to(self, handle, anchor):
def add_quadratic_bezier_curve_to(self, handle: np.ndarray, anchor: np.ndarray):
self.throw_error_if_no_points()
if self.has_new_path_started():
self.append_points([handle, anchor])
else:
self.append_points([self.get_last_point(), handle, anchor])
def add_line_to(self, point):
def add_line_to(self, point: np.ndarray):
end = self.get_points()[-1]
alphas = np.linspace(0, 1, self.n_points_per_curve)
if self.long_lines:
@ -371,7 +419,7 @@ class VMobject(Mobject):
self.append_points(points)
return self
def add_smooth_curve_to(self, point):
def add_smooth_curve_to(self, point: np.ndarray):
if self.has_new_path_started():
self.add_line_to(point)
else:
@ -380,7 +428,7 @@ class VMobject(Mobject):
self.add_quadratic_bezier_curve_to(new_handle, point)
return self
def add_smooth_cubic_curve_to(self, handle, point):
def add_smooth_cubic_curve_to(self, handle: np.ndarray, point: np.ndarray):
self.throw_error_if_no_points()
if self.get_num_points() == 1:
new_handle = self.get_points()[-1]
@ -388,13 +436,13 @@ class VMobject(Mobject):
new_handle = self.get_reflection_of_last_handle()
self.add_cubic_bezier_curve_to(new_handle, handle, point)
def has_new_path_started(self):
def has_new_path_started(self) -> bool:
return self.get_num_points() % self.n_points_per_curve == 1
def get_last_point(self):
def get_last_point(self) -> np.ndarray:
return self.get_points()[-1]
def get_reflection_of_last_handle(self):
def get_reflection_of_last_handle(self) -> np.ndarray:
points = self.get_points()
return 2 * points[-1] - points[-2]
@ -402,12 +450,16 @@ class VMobject(Mobject):
if not self.is_closed():
self.add_line_to(self.get_subpaths()[-1][0])
def is_closed(self):
def is_closed(self) -> bool:
return self.consider_points_equals(
self.get_points()[0], self.get_points()[-1]
)
def subdivide_sharp_curves(self, angle_threshold=30 * DEGREES, recurse=True):
def subdivide_sharp_curves(
self,
angle_threshold: float = 30 * DEGREES,
recurse: bool = True
):
vmobs = [vm for vm in self.get_family(recurse) if vm.has_points()]
for vmob in vmobs:
new_points = []
@ -425,12 +477,12 @@ class VMobject(Mobject):
vmob.set_points(np.vstack(new_points))
return self
def add_points_as_corners(self, points):
def add_points_as_corners(self, points: Iterable[np.ndarray]):
for point in points:
self.add_line_to(point)
return points
def set_points_as_corners(self, points):
def set_points_as_corners(self, points: Iterable[np.ndarray]):
nppc = self.n_points_per_curve
points = np.array(points)
self.set_anchors_and_handles(*[
@ -439,7 +491,11 @@ class VMobject(Mobject):
])
return self
def set_points_smoothly(self, points, true_smooth=False):
def set_points_smoothly(
self,
points: Iterable[np.ndarray],
true_smooth: bool = False
):
self.set_points_as_corners(points)
if true_smooth:
self.make_smooth()
@ -447,7 +503,7 @@ class VMobject(Mobject):
self.make_approximately_smooth()
return self
def change_anchor_mode(self, mode):
def change_anchor_mode(self, mode: str):
assert(mode in ("jagged", "approx_smooth", "true_smooth"))
nppc = self.n_points_per_curve
for submob in self.family_members_with_points():
@ -492,12 +548,12 @@ class VMobject(Mobject):
self.change_anchor_mode("jagged")
return self
def add_subpath(self, points):
def add_subpath(self, points: Iterable[np.ndarray]):
assert(len(points) % self.n_points_per_curve == 0)
self.append_points(points)
return self
def append_vectorized_mobject(self, vectorized_mobject):
def append_vectorized_mobject(self, vectorized_mobject: "VMobject"):
new_points = list(vectorized_mobject.get_points())
if self.has_new_path_started():
@ -508,11 +564,11 @@ class VMobject(Mobject):
return self
#
def consider_points_equals(self, p0, p1):
def consider_points_equals(self, p0: np.ndarray, p1: np.ndarray) -> bool:
return get_norm(p1 - p0) < self.tolerance_for_point_equality
# Information about the curve
def get_bezier_tuples_from_points(self, points):
def get_bezier_tuples_from_points(self, points: Sequence[np.ndarray]):
nppc = self.n_points_per_curve
remainder = len(points) % nppc
points = points[:len(points) - remainder]
@ -524,7 +580,10 @@ class VMobject(Mobject):
def get_bezier_tuples(self):
return self.get_bezier_tuples_from_points(self.get_points())
def get_subpaths_from_points(self, points):
def get_subpaths_from_points(
self,
points: Sequence[np.ndarray]
) -> list[Sequence[np.ndarray]]:
nppc = self.n_points_per_curve
diffs = points[nppc - 1:-1:nppc] - points[nppc::nppc]
splits = (diffs * diffs).sum(1) > self.tolerance_for_point_equality
@ -541,28 +600,28 @@ class VMobject(Mobject):
if (i2 - i1) >= nppc
]
def get_subpaths(self):
def get_subpaths(self) -> list[Sequence[np.ndarray]]:
return self.get_subpaths_from_points(self.get_points())
def get_nth_curve_points(self, n):
def get_nth_curve_points(self, n: int) -> np.ndarray:
assert(n < self.get_num_curves())
nppc = self.n_points_per_curve
return self.get_points()[nppc * n:nppc * (n + 1)]
def get_nth_curve_function(self, n):
def get_nth_curve_function(self, n: int) -> Callable[[float], np.ndarray]:
return bezier(self.get_nth_curve_points(n))
def get_num_curves(self):
def get_num_curves(self) -> int:
return self.get_num_points() // self.n_points_per_curve
def quick_point_from_proportion(self, alpha):
def quick_point_from_proportion(self, alpha: float) -> np.ndarray:
# Assumes all curves have the same length, so is inaccurate
num_curves = self.get_num_curves()
n, residue = integer_interpolate(0, num_curves, alpha)
curve_func = self.get_nth_curve_function(n)
return curve_func(residue)
def point_from_proportion(self, alpha):
def point_from_proportion(self, alpha: float) -> np.ndarray:
if alpha <= 0:
return self.get_start()
elif alpha >= 1:
@ -584,7 +643,7 @@ class VMobject(Mobject):
residue = inverse_interpolate(partials[i - 1] / full, partials[i] / full, alpha)
return self.get_nth_curve_function(i - 1)(residue)
def get_anchors_and_handles(self):
def get_anchors_and_handles(self) -> list[np.ndarray]:
"""
returns anchors1, handles, anchors2,
where (anchors1[i], handles[i], anchors2[i])
@ -598,14 +657,14 @@ class VMobject(Mobject):
for i in range(nppc)
]
def get_start_anchors(self):
def get_start_anchors(self) -> np.ndarray:
return self.get_points()[0::self.n_points_per_curve]
def get_end_anchors(self):
def get_end_anchors(self) -> np.ndarray:
nppc = self.n_points_per_curve
return self.get_points()[nppc - 1::nppc]
def get_anchors(self):
def get_anchors(self) -> np.ndarray:
points = self.get_points()
if len(points) == 1:
return points
@ -614,7 +673,7 @@ class VMobject(Mobject):
self.get_end_anchors(),
))))
def get_points_without_null_curves(self, atol=1e-9):
def get_points_without_null_curves(self, atol: float=1e-9) -> np.ndarray:
nppc = self.n_points_per_curve
points = self.get_points()
distinct_curves = reduce(op.or_, [
@ -623,7 +682,7 @@ class VMobject(Mobject):
])
return points[distinct_curves.repeat(nppc)]
def get_arc_length(self, n_sample_points=None):
def get_arc_length(self, n_sample_points: int | None = None) -> float:
if n_sample_points is None:
n_sample_points = 4 * self.get_num_curves() + 1
points = np.array([
@ -634,7 +693,7 @@ class VMobject(Mobject):
norms = np.array([get_norm(d) for d in diffs])
return norms.sum()
def get_area_vector(self):
def get_area_vector(self) -> np.ndarray:
# Returns a vector whose length is the area bound by
# the polygon formed by the anchor points, pointing
# in a direction perpendicular to the polygon according
@ -654,7 +713,7 @@ class VMobject(Mobject):
sum((p0[:, 0] + p1[:, 0]) * (p1[:, 1] - p0[:, 1])), # Add up (x1 + x2)*(y2 - y1)
])
def get_unit_normal(self, recompute=False):
def get_unit_normal(self, recompute: bool = False) -> np.ndarray:
if not recompute:
return self.data["unit_normal"][0]
@ -680,7 +739,7 @@ class VMobject(Mobject):
return self
# Alignment
def align_points(self, vmobject):
def align_points(self, vmobject: "VMobject"):
if self.get_num_points() == len(vmobject.get_points()):
return
@ -723,7 +782,7 @@ class VMobject(Mobject):
vmobject.set_points(np.vstack(new_subpaths2))
return self
def insert_n_curves(self, n, recurse=True):
def insert_n_curves(self, n: int, recurse: bool = True):
for mob in self.get_family(recurse):
if mob.get_num_curves() > 0:
new_points = mob.insert_n_curves_to_point_list(n, mob.get_points())
@ -733,7 +792,7 @@ class VMobject(Mobject):
mob.set_points(new_points)
return self
def insert_n_curves_to_point_list(self, n, points):
def insert_n_curves_to_point_list(self, n: int, points: np.ndarray):
nppc = self.n_points_per_curve
if len(points) == 1:
return np.repeat(points, nppc * n, 0)
@ -766,7 +825,13 @@ class VMobject(Mobject):
new_points += partial_quadratic_bezier_points(group, a1, a2)
return np.vstack(new_points)
def interpolate(self, mobject1, mobject2, alpha, *args, **kwargs):
def interpolate(
self,
mobject1: "VMobject",
mobject2: "VMobject",
alpha: float,
*args, **kwargs
):
super().interpolate(mobject1, mobject2, alpha, *args, **kwargs)
if self.has_fill():
tri1 = mobject1.get_triangulation()
@ -775,7 +840,7 @@ class VMobject(Mobject):
self.refresh_triangulation()
return self
def pointwise_become_partial(self, vmobject, a, b):
def pointwise_become_partial(self, vmobject: "VMobject", a: float, b: float):
assert(isinstance(vmobject, VMobject))
if a <= 0 and b >= 1:
self.become(vmobject)
@ -817,7 +882,7 @@ class VMobject(Mobject):
self.set_points(new_points)
return self
def get_subcurve(self, a, b):
def get_subcurve(self, a: float, b: float) -> "VMobject":
vmob = self.copy()
vmob.pointwise_become_partial(self, a, b)
return vmob
@ -829,7 +894,7 @@ class VMobject(Mobject):
mob.needs_new_triangulation = True
return self
def get_triangulation(self, normal_vector=None):
def get_triangulation(self, normal_vector: np.ndarray | None = None):
# Figure out how to triangulate the interior to know
# how to send the points as to the vertex shader.
# First triangles come directly from the points
@ -898,25 +963,30 @@ class VMobject(Mobject):
return wrapper
@triggers_refreshed_triangulation
def set_points(self, points):
def set_points(self, points: npt.ArrayLike):
super().set_points(points)
return self
@triggers_refreshed_triangulation
def set_data(self, data):
def set_data(self, data: dict):
super().set_data(data)
return self
# TODO, how to be smart about tangents here?
@triggers_refreshed_triangulation
def apply_function(self, function, make_smooth=False, **kwargs):
def apply_function(
self,
function: Callable[[np.ndarray], np.ndarray],
make_smooth: bool = False,
**kwargs
):
super().apply_function(function, **kwargs)
if self.make_smooth_after_applying_functions or make_smooth:
self.make_approximately_smooth()
return self
def flip(self, *args, **kwargs):
super().flip(*args, **kwargs)
def flip(self, axis: np.ndarray = UP, **kwargs):
super().flip(axis, **kwargs)
self.refresh_unit_normal()
self.refresh_triangulation()
return self
@ -942,20 +1012,20 @@ class VMobject(Mobject):
wrapper.refresh_id()
return self
def get_fill_shader_wrapper(self):
def get_fill_shader_wrapper(self) -> ShaderWrapper:
self.fill_shader_wrapper.vert_data = self.get_fill_shader_data()
self.fill_shader_wrapper.vert_indices = self.get_fill_shader_vert_indices()
self.fill_shader_wrapper.uniforms = self.get_shader_uniforms()
self.fill_shader_wrapper.depth_test = self.depth_test
return self.fill_shader_wrapper
def get_stroke_shader_wrapper(self):
def get_stroke_shader_wrapper(self) -> ShaderWrapper:
self.stroke_shader_wrapper.vert_data = self.get_stroke_shader_data()
self.stroke_shader_wrapper.uniforms = self.get_stroke_uniforms()
self.stroke_shader_wrapper.depth_test = self.depth_test
return self.stroke_shader_wrapper
def get_shader_wrapper_list(self):
def get_shader_wrapper_list(self) -> list[ShaderWrapper]:
# Build up data lists
fill_shader_wrappers = []
stroke_shader_wrappers = []
@ -984,13 +1054,13 @@ class VMobject(Mobject):
result.append(wrapper)
return result
def get_stroke_uniforms(self):
def get_stroke_uniforms(self) -> dict[str, float]:
result = dict(super().get_shader_uniforms())
result["joint_type"] = JOINT_TYPE_MAP[self.joint_type]
result["flat_stroke"] = float(self.flat_stroke)
return result
def get_stroke_shader_data(self):
def get_stroke_shader_data(self) -> np.ndarray:
points = self.get_points()
if len(self.stroke_data) != len(points):
self.stroke_data = resize_array(self.stroke_data, len(points))
@ -1009,7 +1079,7 @@ class VMobject(Mobject):
return self.stroke_data
def get_fill_shader_data(self):
def get_fill_shader_data(self) -> np.ndarray:
points = self.get_points()
if len(self.fill_data) != len(points):
self.fill_data = resize_array(self.fill_data, len(points))
@ -1025,18 +1095,18 @@ class VMobject(Mobject):
self.get_fill_shader_data()
self.get_stroke_shader_data()
def get_fill_shader_vert_indices(self):
def get_fill_shader_vert_indices(self) -> np.ndarray:
return self.get_triangulation()
class VGroup(VMobject):
def __init__(self, *vmobjects, **kwargs):
def __init__(self, *vmobjects: VMobject, **kwargs):
if not all([isinstance(m, VMobject) for m in vmobjects]):
raise Exception("All submobjects must be of type VMobject")
super().__init__(**kwargs)
self.add(*vmobjects)
def __add__(self: 'VGroup', other: 'VMobject' or 'VGroup'):
def __add__(self, other: VMobject | "VGroup"):
assert(isinstance(other, VMobject))
return self.add(other)
@ -1050,14 +1120,14 @@ class VectorizedPoint(Point, VMobject):
"artificial_height": 0.01,
}
def __init__(self, location=ORIGIN, **kwargs):
def __init__(self, location: np.ndarray = ORIGIN, **kwargs):
Point.__init__(self, **kwargs)
VMobject.__init__(self, **kwargs)
self.set_points(np.array([location]))
class CurvesAsSubmobjects(VGroup):
def __init__(self, vmobject, **kwargs):
def __init__(self, vmobject: VMobject, **kwargs):
super().__init__(**kwargs)
for tup in vmobject.get_bezier_tuples():
part = VMobject()
@ -1073,7 +1143,7 @@ class DashedVMobject(VMobject):
"color": WHITE
}
def __init__(self, vmobject, **kwargs):
def __init__(self, vmobject: VMobject, **kwargs):
super().__init__(**kwargs)
num_dashes = self.num_dashes
ps_ratio = self.positive_space_ratio