Add Self type to vectorized_mobject.py

This commit is contained in:
Grant Sanderson 2023-01-31 13:43:54 -08:00
parent 50343e9629
commit b58224f6c8
2 changed files with 73 additions and 89 deletions

View file

@ -344,7 +344,7 @@ class Mobject(object):
# Family matters
def __getitem__(self, value: int | slice) -> Mobject:
def __getitem__(self, value: int | slice) -> Self:
if isinstance(value, slice):
GroupClass = self.get_group_class()
return GroupClass(*self.split().__getitem__(value))
@ -739,7 +739,7 @@ class Mobject(object):
# Creating new Mobjects from this one
def replicate(self, n: int) -> Group:
def replicate(self, n: int) -> Self:
group_class = self.get_group_class()
return group_class(*(self.copy() for _ in range(n)))
@ -752,7 +752,7 @@ class Mobject(object):
group_by_rows: bool = False,
group_by_cols: bool = False,
**kwargs
) -> Group:
) -> Self:
"""
Returns a new mobject containing multiple copies of this one
arranged in a grid

View file

@ -45,7 +45,7 @@ from manimlib.shader_wrapper import FillShaderWrapper
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Callable, Iterable, Tuple
from typing import Callable, Iterable, Tuple, Any, Self
from manimlib.typing import ManimColor, Vect3, Vect4, Vect3Array, Vect4Array
from moderngl.context import Context
@ -128,29 +128,10 @@ class VMobject(Mobject):
self.uniforms["joint_type"] = JOINT_TYPE_MAP[self.joint_type]
self.uniforms["flat_stroke"] = float(self.flat_stroke)
# These are here just to make type checkers happy
def get_family(self, recurse: bool = True) -> list[VMobject]:
return super().get_family(recurse)
def family_members_with_points(self) -> list[VMobject]:
return super().family_members_with_points()
def replicate(self, n: int) -> VGroup:
return super().replicate(n)
def get_grid(self, *args, **kwargs) -> VGroup:
return super().get_grid(*args, **kwargs)
def __getitem__(self, value: int | slice) -> VMobject:
return super().__getitem__(value)
def __iter__(self) -> Iterable[VMobject]:
return super().__iter__()
def add(self, *vmobjects: VMobject):
def add(self, *vmobjects: VMobject) -> Self:
if not all((isinstance(m, VMobject) for m in vmobjects)):
raise Exception("All submobjects must be of type VMobject")
super().add(*vmobjects)
return super().add(*vmobjects)
# Colors
def init_colors(self):
@ -175,7 +156,7 @@ class VMobject(Mobject):
rgba_array: Vect4Array,
name: str | None = None,
recurse: bool = False
):
) -> Self:
if name is None:
names = ["fill_rgba", "stroke_rgba"]
else:
@ -191,7 +172,7 @@ class VMobject(Mobject):
opacity: float | Iterable[float] | None = None,
border_width: float | None = None,
recurse: bool = True
):
) -> Self:
self.set_rgba_array_by_color(color, opacity, 'fill_rgba', recurse)
if border_width is not None:
for mob in self.get_family(recurse):
@ -205,7 +186,7 @@ class VMobject(Mobject):
opacity: float | Iterable[float] | None = None,
background: bool | None = None,
recurse: bool = True
):
) -> Self:
self.set_rgba_array_by_color(color, opacity, 'stroke_rgba', recurse)
if width is not None:
@ -228,7 +209,7 @@ class VMobject(Mobject):
color: ManimColor | Iterable[ManimColor] = BLACK,
width: float | Iterable[float] = 3,
background: bool = True
):
) -> Self:
self.set_stroke(color, width, background=background)
return self
@ -245,7 +226,7 @@ class VMobject(Mobject):
stroke_background: bool = True,
shading: Tuple[float, float, float] | None = None,
recurse: bool = True
):
) -> Self:
for mob in self.get_family(recurse):
if fill_rgba is not None:
mob.data['fill_rgba'][:] = resize_with_interpolation(fill_rgba, len(mob.data['fill_rgba']))
@ -276,7 +257,7 @@ class VMobject(Mobject):
mob.set_shading(*shading, recurse=False)
return self
def get_style(self):
def get_style(self) -> dict[str, Any]:
data = self.data if self.get_num_points() > 0 else self._data_defaults
return {
"fill_rgba": data['fill_rgba'].copy(),
@ -286,7 +267,7 @@ class VMobject(Mobject):
"shading": self.get_shading(),
}
def match_style(self, vmobject: VMobject, recurse: bool = True):
def match_style(self, vmobject: VMobject, recurse: bool = True) -> Self:
self.set_style(**vmobject.get_style(), recurse=False)
if recurse:
# Does its best to match up submobject lists, and
@ -305,7 +286,7 @@ class VMobject(Mobject):
color: ManimColor | Iterable[ManimColor] | None,
opacity: float | Iterable[float] | None = None,
recurse: bool = True
):
) -> Self:
self.set_fill(color, opacity=opacity, recurse=recurse)
self.set_stroke(color, opacity=opacity, recurse=recurse)
return self
@ -314,16 +295,16 @@ class VMobject(Mobject):
self,
opacity: float | Iterable[float] | None,
recurse: bool = True
):
) -> Self:
self.set_fill(opacity=opacity, recurse=recurse)
self.set_stroke(opacity=opacity, recurse=recurse)
return self
def set_anti_alias_width(self, anti_alias_width: float, recurse: bool = True):
def set_anti_alias_width(self, anti_alias_width: float, recurse: bool = True) -> Self:
self.set_uniform(recurse, anti_alias_width=anti_alias_width)
return self
def fade(self, darkness: float = 0.5, recurse: bool = True):
def fade(self, darkness: float = 0.5, recurse: bool = True) -> Self:
mobs = self.get_family() if recurse else [self]
for mob in mobs:
factor = 1.0 - darkness
@ -407,7 +388,7 @@ class VMobject(Mobject):
return self.get_fill_opacity()
return self.get_stroke_opacity()
def set_flat_stroke(self, flat_stroke: bool = True, recurse: bool = True):
def set_flat_stroke(self, flat_stroke: bool = True, recurse: bool = True) -> Self:
for mob in self.get_family(recurse):
mob.uniforms["flat_stroke"] = float(flat_stroke)
return self
@ -415,7 +396,7 @@ class VMobject(Mobject):
def get_flat_stroke(self) -> bool:
return self.uniforms["flat_stroke"] == 1.0
def set_joint_type(self, joint_type: str, recurse: bool = True):
def set_joint_type(self, joint_type: str, recurse: bool = True) -> Self:
for mob in self.get_family(recurse):
mob.uniforms["joint_type"] = JOINT_TYPE_MAP[joint_type]
return self
@ -428,7 +409,7 @@ class VMobject(Mobject):
anti_alias_width: float = 0,
fill_border_width: float = 0,
recurse: bool=True
):
) -> Self:
super().apply_depth_test(recurse)
self.set_anti_alias_width(anti_alias_width)
self.set_fill(border_width=fill_border_width)
@ -439,14 +420,14 @@ class VMobject(Mobject):
anti_alias_width: float = 1.0,
fill_border_width: float = 0.5,
recurse: bool=True
):
) -> Self:
super().apply_depth_test(recurse)
self.set_anti_alias_width(anti_alias_width)
self.set_fill(border_width=fill_border_width)
return self
@Mobject.affects_family_data
def use_winding_fill(self, value: bool = True, recurse: bool = True):
def use_winding_fill(self, value: bool = True, recurse: bool = True) -> Self:
for submob in self.get_family(recurse):
submob._use_winding_fill = value
if not value and submob.has_points():
@ -458,7 +439,7 @@ class VMobject(Mobject):
self,
anchors: Vect3Array,
handles: Vect3Array,
):
) -> Self:
assert(len(anchors) == len(handles) + 1)
points = resize_array(self.get_points(), 2 * len(anchors) - 1)
points[0::2] = anchors
@ -466,7 +447,7 @@ class VMobject(Mobject):
self.set_points(points)
return self
def start_new_path(self, point: Vect3):
def start_new_path(self, point: Vect3) -> Self:
# Path ends are signaled by a handle point sitting directly
# on top of the previous anchor
if self.has_points():
@ -481,7 +462,7 @@ class VMobject(Mobject):
handle1: Vect3,
handle2: Vect3,
anchor2: Vect3
):
) -> Self:
self.start_new_path(anchor1)
self.add_cubic_bezier_curve_to(handle1, handle2, anchor2)
return self
@ -491,7 +472,7 @@ class VMobject(Mobject):
handle1: Vect3,
handle2: Vect3,
anchor: Vect3,
):
) -> Self:
"""
Add cubic bezier curve to the path.
"""
@ -513,7 +494,7 @@ class VMobject(Mobject):
self.append_points(quad_approx[1:])
return self
def add_quadratic_bezier_curve_to(self, handle: Vect3, anchor: Vect3):
def add_quadratic_bezier_curve_to(self, handle: Vect3, anchor: Vect3) -> Self:
self.throw_error_if_no_points()
last_point = self.get_last_point()
if self.consider_points_equal(handle, last_point):
@ -522,14 +503,14 @@ class VMobject(Mobject):
self.append_points([handle, anchor])
return self
def add_line_to(self, point: Vect3):
def add_line_to(self, point: Vect3) -> Self:
self.throw_error_if_no_points()
last_point = self.get_last_point()
alphas = np.linspace(0, 1, 5 if self.long_lines else 3)
self.append_points(outer_interpolate(last_point, point, alphas[1:]))
return self
def add_smooth_curve_to(self, point: Vect3):
def add_smooth_curve_to(self, point: Vect3) -> Self:
if self.has_new_path_started():
self.add_line_to(point)
else:
@ -538,7 +519,7 @@ class VMobject(Mobject):
self.add_quadratic_bezier_curve_to(new_handle, point)
return self
def add_smooth_cubic_curve_to(self, handle: Vect3, point: Vect3):
def add_smooth_cubic_curve_to(self, handle: Vect3, point: Vect3) -> Self:
self.throw_error_if_no_points()
if self.get_num_points() == 1:
new_handle = handle
@ -559,7 +540,7 @@ class VMobject(Mobject):
points = self.get_points()
return 2 * points[-1] - points[-2]
def close_path(self, smooth: bool = False):
def close_path(self, smooth: bool = False) -> Self:
if self.is_closed():
return self
last_path_start = self.get_subpaths()[-1][0]
@ -577,7 +558,7 @@ class VMobject(Mobject):
self,
tuple_to_subdivisions: Callable,
recurse: bool = True
):
) -> Self:
for vmob in self.get_family(recurse):
if not vmob.has_points():
continue
@ -599,7 +580,7 @@ class VMobject(Mobject):
self,
angle_threshold: float = 30 * DEGREES,
recurse: bool = True
):
) -> Self:
def tuple_to_subdivisions(b0, b1, b2):
angle = angle_between_vectors(b1 - b0, b2 - b1)
return int(angle / angle_threshold)
@ -607,7 +588,7 @@ class VMobject(Mobject):
self.subdivide_curves_by_condition(tuple_to_subdivisions, recurse)
return self
def subdivide_intersections(self, recurse: bool = True, n_subdivisions: int = 1):
def subdivide_intersections(self, recurse: bool = True, n_subdivisions: int = 1) -> Self:
path = self.get_anchors()
def tuple_to_subdivisions(b0, b1, b2):
if line_intersects_path(b0, b1, path):
@ -617,12 +598,12 @@ class VMobject(Mobject):
self.subdivide_curves_by_condition(tuple_to_subdivisions, recurse)
return self
def add_points_as_corners(self, points: Iterable[Vect3]):
def add_points_as_corners(self, points: Iterable[Vect3]) -> Self:
for point in points:
self.add_line_to(point)
return points
return self
def set_points_as_corners(self, points: Iterable[Vect3]):
def set_points_as_corners(self, points: Iterable[Vect3]) -> Self:
anchors = np.array(points)
handles = 0.5 * (anchors[:-1] + anchors[1:])
self.set_anchors_and_handles(anchors, handles)
@ -632,7 +613,7 @@ class VMobject(Mobject):
self,
points: Iterable[Vect3],
approx: bool = True
):
) -> Self:
self.set_points_as_corners(points)
self.make_smooth(approx=approx)
return self
@ -641,7 +622,7 @@ class VMobject(Mobject):
dots = self.get_joint_products()[::2, 3]
return bool((dots > 1 - 1e-3).all())
def change_anchor_mode(self, mode: str):
def change_anchor_mode(self, mode: str) -> Self:
assert(mode in ("jagged", "approx_smooth", "true_smooth"))
subpaths = self.get_subpaths()
self.clear_points()
@ -664,7 +645,7 @@ class VMobject(Mobject):
self.add_subpath(new_subpath)
return self
def make_smooth(self, approx=False, recurse=True):
def make_smooth(self, approx=False, recurse=True) -> Self:
"""
Edits the path so as to pass smoothly through all
the current anchor points.
@ -679,15 +660,16 @@ class VMobject(Mobject):
submob.change_anchor_mode(mode)
return self
def make_approximately_smooth(self, recurse=True):
def make_approximately_smooth(self, recurse=True) -> Self:
self.make_smooth(approx=True, recurse=recurse)
return self
def make_jagged(self, recurse=True):
def make_jagged(self, recurse=True) -> Self:
for submob in self.get_family(recurse):
submob.change_anchor_mode("jagged")
return self
def add_subpath(self, points: Vect3Array):
def add_subpath(self, points: Vect3Array) -> Self:
assert(len(points) % 2 == 1 or len(points) == 0)
if not self.has_points():
self.set_points(points)
@ -697,7 +679,7 @@ class VMobject(Mobject):
self.append_points(points[1:])
return self
def append_vectorized_mobject(self, vmobject: VMobject):
def append_vectorized_mobject(self, vmobject: VMobject) -> Self:
self.add_subpath(vmobject.get_points())
n = vmobject.get_num_points()
self.data[-n:] = vmobject.data
@ -715,7 +697,7 @@ class VMobject(Mobject):
def get_bezier_tuples(self) -> Iterable[Vect3Array]:
return self.get_bezier_tuples_from_points(self.get_points())
def get_subpath_end_indices_from_points(self, points: Vect3Array):
def get_subpath_end_indices_from_points(self, points: Vect3Array) -> np.ndarray:
atol = self.tolerance_for_point_equality
a0, h, a1 = points[0:-1:2], points[1::2], points[2::2]
# An anchor point is considered the end of a path
@ -731,7 +713,7 @@ class VMobject(Mobject):
is_end[:-1] = is_end[:-1] & ~is_end[1:]
return np.array([2 * n for n, end in enumerate(is_end) if end])
def get_subpath_end_indices(self):
def get_subpath_end_indices(self) -> np.ndarray:
return self.get_subpath_end_indices_from_points(self.get_points())
def get_subpaths_from_points(self, points: Vect3Array) -> list[Vect3Array]:
@ -860,7 +842,7 @@ class VMobject(Mobject):
self.data["unit_normal"][:] = normal
return normal
def refresh_unit_normal(self):
def refresh_unit_normal(self) -> Self:
self.get_unit_normal()
return self
@ -870,20 +852,20 @@ class VMobject(Mobject):
axis: Vect3 = OUT,
about_point: Vect3 | None = None,
**kwargs
):
) -> Self:
super().rotate(angle, axis, about_point, **kwargs)
for mob in self.get_family():
mob.refresh_unit_normal()
return self
def ensure_positive_orientation(self, recurse=True):
def ensure_positive_orientation(self, recurse=True) -> Self:
for mob in self.get_family(recurse):
if mob.get_unit_normal()[2] < 0:
mob.reverse_points()
return self
# Alignment
def align_points(self, vmobject: VMobject):
def align_points(self, vmobject: VMobject) -> Self:
winding = self._use_winding_fill and vmobject._use_winding_fill
self.use_winding_fill(winding)
vmobject.use_winding_fill(winding)
@ -940,7 +922,7 @@ class VMobject(Mobject):
mob.get_joint_products()
return self
def invisible_copy(self):
def invisible_copy(self) -> Self:
result = self.copy()
if not result.has_fill() or result.get_num_points() == 0:
return result
@ -948,14 +930,14 @@ class VMobject(Mobject):
result.set_opacity(0)
return result
def insert_n_curves(self, n: int, recurse: bool = True):
def insert_n_curves(self, n: int, recurse: bool = True) -> Self:
for mob in self.get_family(recurse):
if mob.get_num_curves() > 0:
new_points = mob.insert_n_curves_to_point_list(n, mob.get_points())
mob.set_points(new_points)
return self
def insert_n_curves_to_point_list(self, n: int, points: Vect3Array):
def insert_n_curves_to_point_list(self, n: int, points: Vect3Array) -> Vect3Array:
if len(points) == 1:
return np.repeat(points, 2 * n + 1, 0)
@ -988,7 +970,7 @@ class VMobject(Mobject):
mobject2: VMobject,
alpha: float,
*args, **kwargs
):
) -> Self:
super().interpolate(mobject1, mobject2, alpha, *args, **kwargs)
if self.has_fill() and not self._use_winding_fill:
tri1 = mobject1.get_triangulation()
@ -997,7 +979,7 @@ class VMobject(Mobject):
self.refresh_triangulation()
return self
def pointwise_become_partial(self, vmobject: VMobject, a: float, b: float):
def pointwise_become_partial(self, vmobject: VMobject, a: float, b: float) -> Self:
assert(isinstance(vmobject, VMobject))
vm_points = vmobject.get_points()
self.data["joint_product"] = vmobject.data["joint_product"]
@ -1040,12 +1022,12 @@ class VMobject(Mobject):
self.set_points(new_points, refresh_joints=False)
return self
def get_subcurve(self, a: float, b: float) -> VMobject:
def get_subcurve(self, a: float, b: float) -> Self:
vmob = self.copy()
vmob.pointwise_become_partial(self, a, b)
return vmob
def get_outer_vert_indices(self):
def get_outer_vert_indices(self) -> np.ndarray:
"""
Returns the pattern (0, 1, 2, 2, 3, 4, 4, 5, 6, ...)
"""
@ -1056,12 +1038,12 @@ class VMobject(Mobject):
# Data for shaders that may need refreshing
def refresh_triangulation(self):
def refresh_triangulation(self) -> Self:
for mob in self.get_family():
mob.needs_new_triangulation = True
return self
def get_triangulation(self):
def get_triangulation(self) -> np.ndarray:
# Figure out how to triangulate the interior to know
# how to send the points as to the vertex shader.
# First triangles come directly from the points
@ -1118,12 +1100,12 @@ class VMobject(Mobject):
self.needs_new_triangulation = False
return tri_indices
def refresh_joint_products(self):
def refresh_joint_products(self) -> Self:
for mob in self.get_family():
mob.needs_new_joint_products = True
return self
def get_joint_products(self, refresh: bool = False):
def get_joint_products(self, refresh: bool = False) -> np.ndarray:
"""
The 'joint product' is a 4-vector holding the cross and dot
product between tangent vectors at a joint
@ -1174,10 +1156,11 @@ class VMobject(Mobject):
self.data["joint_product"][:, 3] = (vect_to_vert * vect_from_vert).sum(1)
return self.data["joint_product"]
def lock_matching_data(self, vmobject1: VMobject, vmobject2: VMobject):
def lock_matching_data(self, vmobject1: VMobject, vmobject2: VMobject) -> Self:
for mob in [self, vmobject1, vmobject2]:
mob.get_joint_products()
super().lock_matching_data(vmobject1, vmobject2)
return self
def triggers_refreshed_triangulation(func: Callable):
@wraps(func)
@ -1189,7 +1172,7 @@ class VMobject(Mobject):
return self
return wrapper
def set_points(self, points: Vect3Array, refresh_joints: bool = True):
def set_points(self, points: Vect3Array, refresh_joints: bool = True) -> Self:
assert(len(points) == 0 or len(points) % 2 == 1)
super().set_points(points)
self.refresh_triangulation()
@ -1199,13 +1182,13 @@ class VMobject(Mobject):
return self
@triggers_refreshed_triangulation
def append_points(self, points: Vect3Array):
def append_points(self, points: Vect3Array) -> Self:
assert(len(points) % 2 == 0)
super().append_points(points)
return self
@triggers_refreshed_triangulation
def reverse_points(self, recurse: bool = True):
def reverse_points(self, recurse: bool = True) -> Self:
# This will reset which anchors are
# considered path ends
for mob in self.get_family(recurse):
@ -1218,7 +1201,7 @@ class VMobject(Mobject):
return self
@triggers_refreshed_triangulation
def set_data(self, data: np.ndarray):
def set_data(self, data: np.ndarray) -> Self:
super().set_data(data)
return self
@ -1229,15 +1212,16 @@ class VMobject(Mobject):
function: Callable[[Vect3], Vect3],
make_smooth: bool = False,
**kwargs
):
) -> Self:
super().apply_function(function, **kwargs)
if self.make_smooth_after_applying_functions or make_smooth:
self.make_smooth(approx=True)
return self
def apply_points_function(self, *args, **kwargs):
def apply_points_function(self, *args, **kwargs) -> Self:
super().apply_points_function(*args, **kwargs)
self.refresh_joint_products()
return self
# For shaders
def init_shader_data(self, ctx: Context):
@ -1272,7 +1256,7 @@ class VMobject(Mobject):
self.stroke_shader_wrapper,
]
def refresh_shader_wrapper_id(self):
def refresh_shader_wrapper_id(self) -> Self:
if not self._shaders_initialized:
return self
for wrapper in self.shader_wrappers:
@ -1343,7 +1327,7 @@ class VGroup(VMobject):
super().__init__(**kwargs)
self.add(*vmobjects)
def __add__(self, other: VMobject | VGroup):
def __add__(self, other: VMobject) -> Self:
assert(isinstance(other, VMobject))
return self.add(other)