Add Self type to vectorized_mobject.py

This commit is contained in:
Grant Sanderson 2023-01-31 13:43:54 -08:00
parent 50343e9629
commit b58224f6c8
2 changed files with 73 additions and 89 deletions

View file

@ -344,7 +344,7 @@ class Mobject(object):
# Family matters # Family matters
def __getitem__(self, value: int | slice) -> Mobject: def __getitem__(self, value: int | slice) -> Self:
if isinstance(value, slice): if isinstance(value, slice):
GroupClass = self.get_group_class() GroupClass = self.get_group_class()
return GroupClass(*self.split().__getitem__(value)) return GroupClass(*self.split().__getitem__(value))
@ -739,7 +739,7 @@ class Mobject(object):
# Creating new Mobjects from this one # Creating new Mobjects from this one
def replicate(self, n: int) -> Group: def replicate(self, n: int) -> Self:
group_class = self.get_group_class() group_class = self.get_group_class()
return group_class(*(self.copy() for _ in range(n))) return group_class(*(self.copy() for _ in range(n)))
@ -752,7 +752,7 @@ class Mobject(object):
group_by_rows: bool = False, group_by_rows: bool = False,
group_by_cols: bool = False, group_by_cols: bool = False,
**kwargs **kwargs
) -> Group: ) -> Self:
""" """
Returns a new mobject containing multiple copies of this one Returns a new mobject containing multiple copies of this one
arranged in a grid arranged in a grid

View file

@ -45,7 +45,7 @@ from manimlib.shader_wrapper import FillShaderWrapper
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Callable, Iterable, Tuple from typing import Callable, Iterable, Tuple, Any, Self
from manimlib.typing import ManimColor, Vect3, Vect4, Vect3Array, Vect4Array from manimlib.typing import ManimColor, Vect3, Vect4, Vect3Array, Vect4Array
from moderngl.context import Context from moderngl.context import Context
@ -128,29 +128,10 @@ class VMobject(Mobject):
self.uniforms["joint_type"] = JOINT_TYPE_MAP[self.joint_type] self.uniforms["joint_type"] = JOINT_TYPE_MAP[self.joint_type]
self.uniforms["flat_stroke"] = float(self.flat_stroke) self.uniforms["flat_stroke"] = float(self.flat_stroke)
# These are here just to make type checkers happy def add(self, *vmobjects: VMobject) -> Self:
def get_family(self, recurse: bool = True) -> list[VMobject]:
return super().get_family(recurse)
def family_members_with_points(self) -> list[VMobject]:
return super().family_members_with_points()
def replicate(self, n: int) -> VGroup:
return super().replicate(n)
def get_grid(self, *args, **kwargs) -> VGroup:
return super().get_grid(*args, **kwargs)
def __getitem__(self, value: int | slice) -> VMobject:
return super().__getitem__(value)
def __iter__(self) -> Iterable[VMobject]:
return super().__iter__()
def add(self, *vmobjects: VMobject):
if not all((isinstance(m, VMobject) for m in vmobjects)): if not all((isinstance(m, VMobject) for m in vmobjects)):
raise Exception("All submobjects must be of type VMobject") raise Exception("All submobjects must be of type VMobject")
super().add(*vmobjects) return super().add(*vmobjects)
# Colors # Colors
def init_colors(self): def init_colors(self):
@ -175,7 +156,7 @@ class VMobject(Mobject):
rgba_array: Vect4Array, rgba_array: Vect4Array,
name: str | None = None, name: str | None = None,
recurse: bool = False recurse: bool = False
): ) -> Self:
if name is None: if name is None:
names = ["fill_rgba", "stroke_rgba"] names = ["fill_rgba", "stroke_rgba"]
else: else:
@ -191,7 +172,7 @@ class VMobject(Mobject):
opacity: float | Iterable[float] | None = None, opacity: float | Iterable[float] | None = None,
border_width: float | None = None, border_width: float | None = None,
recurse: bool = True recurse: bool = True
): ) -> Self:
self.set_rgba_array_by_color(color, opacity, 'fill_rgba', recurse) self.set_rgba_array_by_color(color, opacity, 'fill_rgba', recurse)
if border_width is not None: if border_width is not None:
for mob in self.get_family(recurse): for mob in self.get_family(recurse):
@ -205,7 +186,7 @@ class VMobject(Mobject):
opacity: float | Iterable[float] | None = None, opacity: float | Iterable[float] | None = None,
background: bool | None = None, background: bool | None = None,
recurse: bool = True recurse: bool = True
): ) -> Self:
self.set_rgba_array_by_color(color, opacity, 'stroke_rgba', recurse) self.set_rgba_array_by_color(color, opacity, 'stroke_rgba', recurse)
if width is not None: if width is not None:
@ -228,7 +209,7 @@ class VMobject(Mobject):
color: ManimColor | Iterable[ManimColor] = BLACK, color: ManimColor | Iterable[ManimColor] = BLACK,
width: float | Iterable[float] = 3, width: float | Iterable[float] = 3,
background: bool = True background: bool = True
): ) -> Self:
self.set_stroke(color, width, background=background) self.set_stroke(color, width, background=background)
return self return self
@ -245,7 +226,7 @@ class VMobject(Mobject):
stroke_background: bool = True, stroke_background: bool = True,
shading: Tuple[float, float, float] | None = None, shading: Tuple[float, float, float] | None = None,
recurse: bool = True recurse: bool = True
): ) -> Self:
for mob in self.get_family(recurse): for mob in self.get_family(recurse):
if fill_rgba is not None: if fill_rgba is not None:
mob.data['fill_rgba'][:] = resize_with_interpolation(fill_rgba, len(mob.data['fill_rgba'])) mob.data['fill_rgba'][:] = resize_with_interpolation(fill_rgba, len(mob.data['fill_rgba']))
@ -276,7 +257,7 @@ class VMobject(Mobject):
mob.set_shading(*shading, recurse=False) mob.set_shading(*shading, recurse=False)
return self return self
def get_style(self): def get_style(self) -> dict[str, Any]:
data = self.data if self.get_num_points() > 0 else self._data_defaults data = self.data if self.get_num_points() > 0 else self._data_defaults
return { return {
"fill_rgba": data['fill_rgba'].copy(), "fill_rgba": data['fill_rgba'].copy(),
@ -286,7 +267,7 @@ class VMobject(Mobject):
"shading": self.get_shading(), "shading": self.get_shading(),
} }
def match_style(self, vmobject: VMobject, recurse: bool = True): def match_style(self, vmobject: VMobject, recurse: bool = True) -> Self:
self.set_style(**vmobject.get_style(), recurse=False) self.set_style(**vmobject.get_style(), recurse=False)
if recurse: if recurse:
# Does its best to match up submobject lists, and # Does its best to match up submobject lists, and
@ -305,7 +286,7 @@ class VMobject(Mobject):
color: ManimColor | Iterable[ManimColor] | None, color: ManimColor | Iterable[ManimColor] | None,
opacity: float | Iterable[float] | None = None, opacity: float | Iterable[float] | None = None,
recurse: bool = True recurse: bool = True
): ) -> Self:
self.set_fill(color, opacity=opacity, recurse=recurse) self.set_fill(color, opacity=opacity, recurse=recurse)
self.set_stroke(color, opacity=opacity, recurse=recurse) self.set_stroke(color, opacity=opacity, recurse=recurse)
return self return self
@ -314,16 +295,16 @@ class VMobject(Mobject):
self, self,
opacity: float | Iterable[float] | None, opacity: float | Iterable[float] | None,
recurse: bool = True recurse: bool = True
): ) -> Self:
self.set_fill(opacity=opacity, recurse=recurse) self.set_fill(opacity=opacity, recurse=recurse)
self.set_stroke(opacity=opacity, recurse=recurse) self.set_stroke(opacity=opacity, recurse=recurse)
return self return self
def set_anti_alias_width(self, anti_alias_width: float, recurse: bool = True): def set_anti_alias_width(self, anti_alias_width: float, recurse: bool = True) -> Self:
self.set_uniform(recurse, anti_alias_width=anti_alias_width) self.set_uniform(recurse, anti_alias_width=anti_alias_width)
return self return self
def fade(self, darkness: float = 0.5, recurse: bool = True): def fade(self, darkness: float = 0.5, recurse: bool = True) -> Self:
mobs = self.get_family() if recurse else [self] mobs = self.get_family() if recurse else [self]
for mob in mobs: for mob in mobs:
factor = 1.0 - darkness factor = 1.0 - darkness
@ -407,7 +388,7 @@ class VMobject(Mobject):
return self.get_fill_opacity() return self.get_fill_opacity()
return self.get_stroke_opacity() return self.get_stroke_opacity()
def set_flat_stroke(self, flat_stroke: bool = True, recurse: bool = True): def set_flat_stroke(self, flat_stroke: bool = True, recurse: bool = True) -> Self:
for mob in self.get_family(recurse): for mob in self.get_family(recurse):
mob.uniforms["flat_stroke"] = float(flat_stroke) mob.uniforms["flat_stroke"] = float(flat_stroke)
return self return self
@ -415,7 +396,7 @@ class VMobject(Mobject):
def get_flat_stroke(self) -> bool: def get_flat_stroke(self) -> bool:
return self.uniforms["flat_stroke"] == 1.0 return self.uniforms["flat_stroke"] == 1.0
def set_joint_type(self, joint_type: str, recurse: bool = True): def set_joint_type(self, joint_type: str, recurse: bool = True) -> Self:
for mob in self.get_family(recurse): for mob in self.get_family(recurse):
mob.uniforms["joint_type"] = JOINT_TYPE_MAP[joint_type] mob.uniforms["joint_type"] = JOINT_TYPE_MAP[joint_type]
return self return self
@ -428,7 +409,7 @@ class VMobject(Mobject):
anti_alias_width: float = 0, anti_alias_width: float = 0,
fill_border_width: float = 0, fill_border_width: float = 0,
recurse: bool=True recurse: bool=True
): ) -> Self:
super().apply_depth_test(recurse) super().apply_depth_test(recurse)
self.set_anti_alias_width(anti_alias_width) self.set_anti_alias_width(anti_alias_width)
self.set_fill(border_width=fill_border_width) self.set_fill(border_width=fill_border_width)
@ -439,14 +420,14 @@ class VMobject(Mobject):
anti_alias_width: float = 1.0, anti_alias_width: float = 1.0,
fill_border_width: float = 0.5, fill_border_width: float = 0.5,
recurse: bool=True recurse: bool=True
): ) -> Self:
super().apply_depth_test(recurse) super().apply_depth_test(recurse)
self.set_anti_alias_width(anti_alias_width) self.set_anti_alias_width(anti_alias_width)
self.set_fill(border_width=fill_border_width) self.set_fill(border_width=fill_border_width)
return self return self
@Mobject.affects_family_data @Mobject.affects_family_data
def use_winding_fill(self, value: bool = True, recurse: bool = True): def use_winding_fill(self, value: bool = True, recurse: bool = True) -> Self:
for submob in self.get_family(recurse): for submob in self.get_family(recurse):
submob._use_winding_fill = value submob._use_winding_fill = value
if not value and submob.has_points(): if not value and submob.has_points():
@ -458,7 +439,7 @@ class VMobject(Mobject):
self, self,
anchors: Vect3Array, anchors: Vect3Array,
handles: Vect3Array, handles: Vect3Array,
): ) -> Self:
assert(len(anchors) == len(handles) + 1) assert(len(anchors) == len(handles) + 1)
points = resize_array(self.get_points(), 2 * len(anchors) - 1) points = resize_array(self.get_points(), 2 * len(anchors) - 1)
points[0::2] = anchors points[0::2] = anchors
@ -466,7 +447,7 @@ class VMobject(Mobject):
self.set_points(points) self.set_points(points)
return self return self
def start_new_path(self, point: Vect3): def start_new_path(self, point: Vect3) -> Self:
# Path ends are signaled by a handle point sitting directly # Path ends are signaled by a handle point sitting directly
# on top of the previous anchor # on top of the previous anchor
if self.has_points(): if self.has_points():
@ -481,7 +462,7 @@ class VMobject(Mobject):
handle1: Vect3, handle1: Vect3,
handle2: Vect3, handle2: Vect3,
anchor2: Vect3 anchor2: Vect3
): ) -> Self:
self.start_new_path(anchor1) self.start_new_path(anchor1)
self.add_cubic_bezier_curve_to(handle1, handle2, anchor2) self.add_cubic_bezier_curve_to(handle1, handle2, anchor2)
return self return self
@ -491,7 +472,7 @@ class VMobject(Mobject):
handle1: Vect3, handle1: Vect3,
handle2: Vect3, handle2: Vect3,
anchor: Vect3, anchor: Vect3,
): ) -> Self:
""" """
Add cubic bezier curve to the path. Add cubic bezier curve to the path.
""" """
@ -513,7 +494,7 @@ class VMobject(Mobject):
self.append_points(quad_approx[1:]) self.append_points(quad_approx[1:])
return self return self
def add_quadratic_bezier_curve_to(self, handle: Vect3, anchor: Vect3): def add_quadratic_bezier_curve_to(self, handle: Vect3, anchor: Vect3) -> Self:
self.throw_error_if_no_points() self.throw_error_if_no_points()
last_point = self.get_last_point() last_point = self.get_last_point()
if self.consider_points_equal(handle, last_point): if self.consider_points_equal(handle, last_point):
@ -522,14 +503,14 @@ class VMobject(Mobject):
self.append_points([handle, anchor]) self.append_points([handle, anchor])
return self return self
def add_line_to(self, point: Vect3): def add_line_to(self, point: Vect3) -> Self:
self.throw_error_if_no_points() self.throw_error_if_no_points()
last_point = self.get_last_point() last_point = self.get_last_point()
alphas = np.linspace(0, 1, 5 if self.long_lines else 3) alphas = np.linspace(0, 1, 5 if self.long_lines else 3)
self.append_points(outer_interpolate(last_point, point, alphas[1:])) self.append_points(outer_interpolate(last_point, point, alphas[1:]))
return self return self
def add_smooth_curve_to(self, point: Vect3): def add_smooth_curve_to(self, point: Vect3) -> Self:
if self.has_new_path_started(): if self.has_new_path_started():
self.add_line_to(point) self.add_line_to(point)
else: else:
@ -538,7 +519,7 @@ class VMobject(Mobject):
self.add_quadratic_bezier_curve_to(new_handle, point) self.add_quadratic_bezier_curve_to(new_handle, point)
return self return self
def add_smooth_cubic_curve_to(self, handle: Vect3, point: Vect3): def add_smooth_cubic_curve_to(self, handle: Vect3, point: Vect3) -> Self:
self.throw_error_if_no_points() self.throw_error_if_no_points()
if self.get_num_points() == 1: if self.get_num_points() == 1:
new_handle = handle new_handle = handle
@ -559,7 +540,7 @@ class VMobject(Mobject):
points = self.get_points() points = self.get_points()
return 2 * points[-1] - points[-2] return 2 * points[-1] - points[-2]
def close_path(self, smooth: bool = False): def close_path(self, smooth: bool = False) -> Self:
if self.is_closed(): if self.is_closed():
return self return self
last_path_start = self.get_subpaths()[-1][0] last_path_start = self.get_subpaths()[-1][0]
@ -577,7 +558,7 @@ class VMobject(Mobject):
self, self,
tuple_to_subdivisions: Callable, tuple_to_subdivisions: Callable,
recurse: bool = True recurse: bool = True
): ) -> Self:
for vmob in self.get_family(recurse): for vmob in self.get_family(recurse):
if not vmob.has_points(): if not vmob.has_points():
continue continue
@ -599,7 +580,7 @@ class VMobject(Mobject):
self, self,
angle_threshold: float = 30 * DEGREES, angle_threshold: float = 30 * DEGREES,
recurse: bool = True recurse: bool = True
): ) -> Self:
def tuple_to_subdivisions(b0, b1, b2): def tuple_to_subdivisions(b0, b1, b2):
angle = angle_between_vectors(b1 - b0, b2 - b1) angle = angle_between_vectors(b1 - b0, b2 - b1)
return int(angle / angle_threshold) return int(angle / angle_threshold)
@ -607,7 +588,7 @@ class VMobject(Mobject):
self.subdivide_curves_by_condition(tuple_to_subdivisions, recurse) self.subdivide_curves_by_condition(tuple_to_subdivisions, recurse)
return self return self
def subdivide_intersections(self, recurse: bool = True, n_subdivisions: int = 1): def subdivide_intersections(self, recurse: bool = True, n_subdivisions: int = 1) -> Self:
path = self.get_anchors() path = self.get_anchors()
def tuple_to_subdivisions(b0, b1, b2): def tuple_to_subdivisions(b0, b1, b2):
if line_intersects_path(b0, b1, path): if line_intersects_path(b0, b1, path):
@ -617,12 +598,12 @@ class VMobject(Mobject):
self.subdivide_curves_by_condition(tuple_to_subdivisions, recurse) self.subdivide_curves_by_condition(tuple_to_subdivisions, recurse)
return self return self
def add_points_as_corners(self, points: Iterable[Vect3]): def add_points_as_corners(self, points: Iterable[Vect3]) -> Self:
for point in points: for point in points:
self.add_line_to(point) self.add_line_to(point)
return points return self
def set_points_as_corners(self, points: Iterable[Vect3]): def set_points_as_corners(self, points: Iterable[Vect3]) -> Self:
anchors = np.array(points) anchors = np.array(points)
handles = 0.5 * (anchors[:-1] + anchors[1:]) handles = 0.5 * (anchors[:-1] + anchors[1:])
self.set_anchors_and_handles(anchors, handles) self.set_anchors_and_handles(anchors, handles)
@ -632,7 +613,7 @@ class VMobject(Mobject):
self, self,
points: Iterable[Vect3], points: Iterable[Vect3],
approx: bool = True approx: bool = True
): ) -> Self:
self.set_points_as_corners(points) self.set_points_as_corners(points)
self.make_smooth(approx=approx) self.make_smooth(approx=approx)
return self return self
@ -641,7 +622,7 @@ class VMobject(Mobject):
dots = self.get_joint_products()[::2, 3] dots = self.get_joint_products()[::2, 3]
return bool((dots > 1 - 1e-3).all()) return bool((dots > 1 - 1e-3).all())
def change_anchor_mode(self, mode: str): def change_anchor_mode(self, mode: str) -> Self:
assert(mode in ("jagged", "approx_smooth", "true_smooth")) assert(mode in ("jagged", "approx_smooth", "true_smooth"))
subpaths = self.get_subpaths() subpaths = self.get_subpaths()
self.clear_points() self.clear_points()
@ -664,7 +645,7 @@ class VMobject(Mobject):
self.add_subpath(new_subpath) self.add_subpath(new_subpath)
return self return self
def make_smooth(self, approx=False, recurse=True): def make_smooth(self, approx=False, recurse=True) -> Self:
""" """
Edits the path so as to pass smoothly through all Edits the path so as to pass smoothly through all
the current anchor points. the current anchor points.
@ -679,15 +660,16 @@ class VMobject(Mobject):
submob.change_anchor_mode(mode) submob.change_anchor_mode(mode)
return self return self
def make_approximately_smooth(self, recurse=True): def make_approximately_smooth(self, recurse=True) -> Self:
self.make_smooth(approx=True, recurse=recurse) self.make_smooth(approx=True, recurse=recurse)
return self
def make_jagged(self, recurse=True): def make_jagged(self, recurse=True) -> Self:
for submob in self.get_family(recurse): for submob in self.get_family(recurse):
submob.change_anchor_mode("jagged") submob.change_anchor_mode("jagged")
return self return self
def add_subpath(self, points: Vect3Array): def add_subpath(self, points: Vect3Array) -> Self:
assert(len(points) % 2 == 1 or len(points) == 0) assert(len(points) % 2 == 1 or len(points) == 0)
if not self.has_points(): if not self.has_points():
self.set_points(points) self.set_points(points)
@ -697,7 +679,7 @@ class VMobject(Mobject):
self.append_points(points[1:]) self.append_points(points[1:])
return self return self
def append_vectorized_mobject(self, vmobject: VMobject): def append_vectorized_mobject(self, vmobject: VMobject) -> Self:
self.add_subpath(vmobject.get_points()) self.add_subpath(vmobject.get_points())
n = vmobject.get_num_points() n = vmobject.get_num_points()
self.data[-n:] = vmobject.data self.data[-n:] = vmobject.data
@ -715,7 +697,7 @@ class VMobject(Mobject):
def get_bezier_tuples(self) -> Iterable[Vect3Array]: def get_bezier_tuples(self) -> Iterable[Vect3Array]:
return self.get_bezier_tuples_from_points(self.get_points()) return self.get_bezier_tuples_from_points(self.get_points())
def get_subpath_end_indices_from_points(self, points: Vect3Array): def get_subpath_end_indices_from_points(self, points: Vect3Array) -> np.ndarray:
atol = self.tolerance_for_point_equality atol = self.tolerance_for_point_equality
a0, h, a1 = points[0:-1:2], points[1::2], points[2::2] a0, h, a1 = points[0:-1:2], points[1::2], points[2::2]
# An anchor point is considered the end of a path # An anchor point is considered the end of a path
@ -731,7 +713,7 @@ class VMobject(Mobject):
is_end[:-1] = is_end[:-1] & ~is_end[1:] is_end[:-1] = is_end[:-1] & ~is_end[1:]
return np.array([2 * n for n, end in enumerate(is_end) if end]) return np.array([2 * n for n, end in enumerate(is_end) if end])
def get_subpath_end_indices(self): def get_subpath_end_indices(self) -> np.ndarray:
return self.get_subpath_end_indices_from_points(self.get_points()) return self.get_subpath_end_indices_from_points(self.get_points())
def get_subpaths_from_points(self, points: Vect3Array) -> list[Vect3Array]: def get_subpaths_from_points(self, points: Vect3Array) -> list[Vect3Array]:
@ -860,7 +842,7 @@ class VMobject(Mobject):
self.data["unit_normal"][:] = normal self.data["unit_normal"][:] = normal
return normal return normal
def refresh_unit_normal(self): def refresh_unit_normal(self) -> Self:
self.get_unit_normal() self.get_unit_normal()
return self return self
@ -870,20 +852,20 @@ class VMobject(Mobject):
axis: Vect3 = OUT, axis: Vect3 = OUT,
about_point: Vect3 | None = None, about_point: Vect3 | None = None,
**kwargs **kwargs
): ) -> Self:
super().rotate(angle, axis, about_point, **kwargs) super().rotate(angle, axis, about_point, **kwargs)
for mob in self.get_family(): for mob in self.get_family():
mob.refresh_unit_normal() mob.refresh_unit_normal()
return self return self
def ensure_positive_orientation(self, recurse=True): def ensure_positive_orientation(self, recurse=True) -> Self:
for mob in self.get_family(recurse): for mob in self.get_family(recurse):
if mob.get_unit_normal()[2] < 0: if mob.get_unit_normal()[2] < 0:
mob.reverse_points() mob.reverse_points()
return self return self
# Alignment # Alignment
def align_points(self, vmobject: VMobject): def align_points(self, vmobject: VMobject) -> Self:
winding = self._use_winding_fill and vmobject._use_winding_fill winding = self._use_winding_fill and vmobject._use_winding_fill
self.use_winding_fill(winding) self.use_winding_fill(winding)
vmobject.use_winding_fill(winding) vmobject.use_winding_fill(winding)
@ -940,7 +922,7 @@ class VMobject(Mobject):
mob.get_joint_products() mob.get_joint_products()
return self return self
def invisible_copy(self): def invisible_copy(self) -> Self:
result = self.copy() result = self.copy()
if not result.has_fill() or result.get_num_points() == 0: if not result.has_fill() or result.get_num_points() == 0:
return result return result
@ -948,14 +930,14 @@ class VMobject(Mobject):
result.set_opacity(0) result.set_opacity(0)
return result return result
def insert_n_curves(self, n: int, recurse: bool = True): def insert_n_curves(self, n: int, recurse: bool = True) -> Self:
for mob in self.get_family(recurse): for mob in self.get_family(recurse):
if mob.get_num_curves() > 0: if mob.get_num_curves() > 0:
new_points = mob.insert_n_curves_to_point_list(n, mob.get_points()) new_points = mob.insert_n_curves_to_point_list(n, mob.get_points())
mob.set_points(new_points) mob.set_points(new_points)
return self return self
def insert_n_curves_to_point_list(self, n: int, points: Vect3Array): def insert_n_curves_to_point_list(self, n: int, points: Vect3Array) -> Vect3Array:
if len(points) == 1: if len(points) == 1:
return np.repeat(points, 2 * n + 1, 0) return np.repeat(points, 2 * n + 1, 0)
@ -988,7 +970,7 @@ class VMobject(Mobject):
mobject2: VMobject, mobject2: VMobject,
alpha: float, alpha: float,
*args, **kwargs *args, **kwargs
): ) -> Self:
super().interpolate(mobject1, mobject2, alpha, *args, **kwargs) super().interpolate(mobject1, mobject2, alpha, *args, **kwargs)
if self.has_fill() and not self._use_winding_fill: if self.has_fill() and not self._use_winding_fill:
tri1 = mobject1.get_triangulation() tri1 = mobject1.get_triangulation()
@ -997,7 +979,7 @@ class VMobject(Mobject):
self.refresh_triangulation() self.refresh_triangulation()
return self return self
def pointwise_become_partial(self, vmobject: VMobject, a: float, b: float): def pointwise_become_partial(self, vmobject: VMobject, a: float, b: float) -> Self:
assert(isinstance(vmobject, VMobject)) assert(isinstance(vmobject, VMobject))
vm_points = vmobject.get_points() vm_points = vmobject.get_points()
self.data["joint_product"] = vmobject.data["joint_product"] self.data["joint_product"] = vmobject.data["joint_product"]
@ -1040,12 +1022,12 @@ class VMobject(Mobject):
self.set_points(new_points, refresh_joints=False) self.set_points(new_points, refresh_joints=False)
return self return self
def get_subcurve(self, a: float, b: float) -> VMobject: def get_subcurve(self, a: float, b: float) -> Self:
vmob = self.copy() vmob = self.copy()
vmob.pointwise_become_partial(self, a, b) vmob.pointwise_become_partial(self, a, b)
return vmob return vmob
def get_outer_vert_indices(self): def get_outer_vert_indices(self) -> np.ndarray:
""" """
Returns the pattern (0, 1, 2, 2, 3, 4, 4, 5, 6, ...) Returns the pattern (0, 1, 2, 2, 3, 4, 4, 5, 6, ...)
""" """
@ -1056,12 +1038,12 @@ class VMobject(Mobject):
# Data for shaders that may need refreshing # Data for shaders that may need refreshing
def refresh_triangulation(self): def refresh_triangulation(self) -> Self:
for mob in self.get_family(): for mob in self.get_family():
mob.needs_new_triangulation = True mob.needs_new_triangulation = True
return self return self
def get_triangulation(self): def get_triangulation(self) -> np.ndarray:
# Figure out how to triangulate the interior to know # Figure out how to triangulate the interior to know
# how to send the points as to the vertex shader. # how to send the points as to the vertex shader.
# First triangles come directly from the points # First triangles come directly from the points
@ -1118,12 +1100,12 @@ class VMobject(Mobject):
self.needs_new_triangulation = False self.needs_new_triangulation = False
return tri_indices return tri_indices
def refresh_joint_products(self): def refresh_joint_products(self) -> Self:
for mob in self.get_family(): for mob in self.get_family():
mob.needs_new_joint_products = True mob.needs_new_joint_products = True
return self return self
def get_joint_products(self, refresh: bool = False): def get_joint_products(self, refresh: bool = False) -> np.ndarray:
""" """
The 'joint product' is a 4-vector holding the cross and dot The 'joint product' is a 4-vector holding the cross and dot
product between tangent vectors at a joint product between tangent vectors at a joint
@ -1174,10 +1156,11 @@ class VMobject(Mobject):
self.data["joint_product"][:, 3] = (vect_to_vert * vect_from_vert).sum(1) self.data["joint_product"][:, 3] = (vect_to_vert * vect_from_vert).sum(1)
return self.data["joint_product"] return self.data["joint_product"]
def lock_matching_data(self, vmobject1: VMobject, vmobject2: VMobject): def lock_matching_data(self, vmobject1: VMobject, vmobject2: VMobject) -> Self:
for mob in [self, vmobject1, vmobject2]: for mob in [self, vmobject1, vmobject2]:
mob.get_joint_products() mob.get_joint_products()
super().lock_matching_data(vmobject1, vmobject2) super().lock_matching_data(vmobject1, vmobject2)
return self
def triggers_refreshed_triangulation(func: Callable): def triggers_refreshed_triangulation(func: Callable):
@wraps(func) @wraps(func)
@ -1189,7 +1172,7 @@ class VMobject(Mobject):
return self return self
return wrapper return wrapper
def set_points(self, points: Vect3Array, refresh_joints: bool = True): def set_points(self, points: Vect3Array, refresh_joints: bool = True) -> Self:
assert(len(points) == 0 or len(points) % 2 == 1) assert(len(points) == 0 or len(points) % 2 == 1)
super().set_points(points) super().set_points(points)
self.refresh_triangulation() self.refresh_triangulation()
@ -1199,13 +1182,13 @@ class VMobject(Mobject):
return self return self
@triggers_refreshed_triangulation @triggers_refreshed_triangulation
def append_points(self, points: Vect3Array): def append_points(self, points: Vect3Array) -> Self:
assert(len(points) % 2 == 0) assert(len(points) % 2 == 0)
super().append_points(points) super().append_points(points)
return self return self
@triggers_refreshed_triangulation @triggers_refreshed_triangulation
def reverse_points(self, recurse: bool = True): def reverse_points(self, recurse: bool = True) -> Self:
# This will reset which anchors are # This will reset which anchors are
# considered path ends # considered path ends
for mob in self.get_family(recurse): for mob in self.get_family(recurse):
@ -1218,7 +1201,7 @@ class VMobject(Mobject):
return self return self
@triggers_refreshed_triangulation @triggers_refreshed_triangulation
def set_data(self, data: np.ndarray): def set_data(self, data: np.ndarray) -> Self:
super().set_data(data) super().set_data(data)
return self return self
@ -1229,15 +1212,16 @@ class VMobject(Mobject):
function: Callable[[Vect3], Vect3], function: Callable[[Vect3], Vect3],
make_smooth: bool = False, make_smooth: bool = False,
**kwargs **kwargs
): ) -> Self:
super().apply_function(function, **kwargs) super().apply_function(function, **kwargs)
if self.make_smooth_after_applying_functions or make_smooth: if self.make_smooth_after_applying_functions or make_smooth:
self.make_smooth(approx=True) self.make_smooth(approx=True)
return self return self
def apply_points_function(self, *args, **kwargs): def apply_points_function(self, *args, **kwargs) -> Self:
super().apply_points_function(*args, **kwargs) super().apply_points_function(*args, **kwargs)
self.refresh_joint_products() self.refresh_joint_products()
return self
# For shaders # For shaders
def init_shader_data(self, ctx: Context): def init_shader_data(self, ctx: Context):
@ -1272,7 +1256,7 @@ class VMobject(Mobject):
self.stroke_shader_wrapper, self.stroke_shader_wrapper,
] ]
def refresh_shader_wrapper_id(self): def refresh_shader_wrapper_id(self) -> Self:
if not self._shaders_initialized: if not self._shaders_initialized:
return self return self
for wrapper in self.shader_wrappers: for wrapper in self.shader_wrappers:
@ -1343,7 +1327,7 @@ class VGroup(VMobject):
super().__init__(**kwargs) super().__init__(**kwargs)
self.add(*vmobjects) self.add(*vmobjects)
def __add__(self, other: VMobject | VGroup): def __add__(self, other: VMobject) -> Self:
assert(isinstance(other, VMobject)) assert(isinstance(other, VMobject))
return self.add(other) return self.add(other)