Merge branch 'add-self-type' into video-work

This commit is contained in:
Grant Sanderson 2023-01-31 14:26:31 -08:00
commit 4c327cd5d2
14 changed files with 333 additions and 333 deletions

View file

@ -11,7 +11,7 @@ from manimlib.utils.rate_functions import smooth
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Callable, List, Iterable from typing import Callable, List, Iterable, Self
from manimlib.typing import ManimColor, Vect3 from manimlib.typing import ManimColor, Vect3
@ -49,7 +49,7 @@ class AnimatedBoundary(VGroup):
lambda m, dt: self.update_boundary_copies(dt) lambda m, dt: self.update_boundary_copies(dt)
) )
def update_boundary_copies(self, dt: float) -> None: def update_boundary_copies(self, dt: float) -> Self:
# Not actual time, but something which passes at # Not actual time, but something which passes at
# an altered rate to make the implementation below # an altered rate to make the implementation below
# cleaner # cleaner
@ -79,6 +79,7 @@ class AnimatedBoundary(VGroup):
) )
self.total_time += dt self.total_time += dt
return self
def full_family_become_partial( def full_family_become_partial(
self, self,
@ -86,7 +87,7 @@ class AnimatedBoundary(VGroup):
mob2: VMobject, mob2: VMobject,
a: float, a: float,
b: float b: float
): ) -> Self:
family1 = mob1.family_members_with_points() family1 = mob1.family_members_with_points()
family2 = mob2.family_members_with_points() family2 = mob2.family_members_with_points()
for sm1, sm2 in zip(family1, family2): for sm1, sm2 in zip(family1, family2):
@ -118,7 +119,7 @@ class TracedPath(VMobject):
self.traced_points: list[np.ndarray] = [] self.traced_points: list[np.ndarray] = []
self.add_updater(lambda m, dt: m.update_path(dt)) self.add_updater(lambda m, dt: m.update_path(dt))
def update_path(self, dt: float): def update_path(self, dt: float) -> Self:
if dt == 0: if dt == 0:
return self return self
point = self.traced_point_func().copy() point = self.traced_point_func().copy()

View file

@ -21,6 +21,7 @@ from manimlib.mobject.svg.tex_mobject import Tex
from manimlib.mobject.types.dot_cloud import DotCloud from manimlib.mobject.types.dot_cloud import DotCloud
from manimlib.mobject.types.surface import ParametricSurface from manimlib.mobject.types.surface import ParametricSurface
from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.mobject.types.vectorized_mobject import VGroup
from manimlib.mobject.types.vectorized_mobject import VMobject
from manimlib.utils.dict_ops import merge_dicts_recursively from manimlib.utils.dict_ops import merge_dicts_recursively
from manimlib.utils.simple_functions import binary_search from manimlib.utils.simple_functions import binary_search
from manimlib.utils.space_ops import angle_of_vector from manimlib.utils.space_ops import angle_of_vector
@ -31,7 +32,7 @@ from manimlib.utils.space_ops import normalize
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Callable, Iterable, Sequence, Type, TypeVar from typing import Callable, Iterable, Sequence, Type, TypeVar, Optional, Self
from manimlib.mobject.mobject import Mobject from manimlib.mobject.mobject import Mobject
from manimlib.typing import ManimColor, Vect3, Vect3Array, VectN, RangeSpecifier from manimlib.typing import ManimColor, Vect3, Vect3Array, VectN, RangeSpecifier
@ -235,7 +236,13 @@ class CoordinateSystem(ABC):
""" """
return self.input_to_graph_point(x, graph) return self.input_to_graph_point(x, graph)
def bind_graph_to_func(self, graph, func, jagged=False, get_discontinuities=None): def bind_graph_to_func(
self,
graph: VMobject,
func: Callable[[Vect3], Vect3],
jagged: bool = False,
get_discontinuities: Optional[Callable[[], Vect3]] = None
) -> VMobject:
""" """
Use for graphing functions which might change over time, or change with Use for graphing functions which might change over time, or change with
conditions conditions
@ -659,7 +666,7 @@ class NumberPlane(Axes):
kwargs["buff"] = 0 kwargs["buff"] = 0
return Arrow(self.c2p(0, 0), self.c2p(*coords), **kwargs) return Arrow(self.c2p(0, 0), self.c2p(*coords), **kwargs)
def prepare_for_nonlinear_transform(self, num_inserted_curves: int = 50): def prepare_for_nonlinear_transform(self, num_inserted_curves: int = 50) -> Self:
for mob in self.family_members_with_points(): for mob in self.family_members_with_points():
num_curves = mob.get_num_curves() num_curves = mob.get_num_curves()
if num_inserted_curves > num_curves: if num_inserted_curves > num_curves:
@ -698,7 +705,7 @@ class ComplexPlane(NumberPlane):
skip_first: bool = True, skip_first: bool = True,
font_size: int = 36, font_size: int = 36,
**kwargs **kwargs
): ) -> Self:
if numbers is None: if numbers is None:
numbers = self.get_default_coordinate_values(skip_first) numbers = self.get_default_coordinate_values(skip_first)

View file

@ -30,7 +30,7 @@ from manimlib.utils.space_ops import rotation_matrix_transpose
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Iterable from typing import Iterable, Self, Optional
from manimlib.typing import ManimColor, Vect3, Vect3Array from manimlib.typing import ManimColor, Vect3, Vect3Array
@ -67,7 +67,7 @@ class TipableVMobject(VMobject):
) )
# Adding, Creating, Modifying tips # Adding, Creating, Modifying tips
def add_tip(self, at_start: bool = False, **kwargs): def add_tip(self, at_start: bool = False, **kwargs) -> Self:
""" """
Adds a tip to the TipableVMobject instance, recognising Adds a tip to the TipableVMobject instance, recognising
that the endpoints might need to be switched if it's that the endpoints might need to be switched if it's
@ -112,7 +112,7 @@ class TipableVMobject(VMobject):
tip.shift(anchor - tip.get_tip_point()) tip.shift(anchor - tip.get_tip_point())
return tip return tip
def reset_endpoints_based_on_tip(self, tip: ArrowTip, at_start: bool): def reset_endpoints_based_on_tip(self, tip: ArrowTip, at_start: bool) -> Self:
if self.get_length() == 0: if self.get_length() == 0:
# Zero length, put_start_and_end_on wouldn't # Zero length, put_start_and_end_on wouldn't
# work # work
@ -127,7 +127,7 @@ class TipableVMobject(VMobject):
self.put_start_and_end_on(start, end) self.put_start_and_end_on(start, end)
return self return self
def asign_tip_attr(self, tip: ArrowTip, at_start: bool): def asign_tip_attr(self, tip: ArrowTip, at_start: bool) -> Self:
if at_start: if at_start:
self.start_tip = tip self.start_tip = tip
else: else:
@ -258,7 +258,7 @@ class Arc(TipableVMobject):
angle = angle_of_vector(self.get_end() - self.get_arc_center()) angle = angle_of_vector(self.get_end() - self.get_arc_center())
return angle % TAU return angle % TAU
def move_arc_center_to(self, point: Vect3): def move_arc_center_to(self, point: Vect3) -> Self:
self.shift(point - self.get_arc_center()) self.shift(point - self.get_arc_center())
return self return self
@ -318,7 +318,7 @@ class Circle(Arc):
dim_to_match: int = 0, dim_to_match: int = 0,
stretch: bool = False, stretch: bool = False,
buff: float = MED_SMALL_BUFF buff: float = MED_SMALL_BUFF
): ) -> Self:
self.replace(mobject, dim_to_match, stretch) self.replace(mobject, dim_to_match, stretch)
self.stretch((self.get_width() + 2 * buff) / self.get_width(), 0) self.stretch((self.get_width() + 2 * buff) / self.get_width(), 0)
self.stretch((self.get_height() + 2 * buff) / self.get_height(), 1) self.stretch((self.get_height() + 2 * buff) / self.get_height(), 1)
@ -475,7 +475,7 @@ class Line(TipableVMobject):
end: Vect3, end: Vect3,
buff: float = 0, buff: float = 0,
path_arc: float = 0 path_arc: float = 0
): ) -> Self:
vect = end - start vect = end - start
dist = get_norm(vect) dist = get_norm(vect)
if np.isclose(dist, 0): if np.isclose(dist, 0):
@ -504,9 +504,10 @@ class Line(TipableVMobject):
self.set_points_as_corners([start, end]) self.set_points_as_corners([start, end])
return self return self
def set_path_arc(self, new_value: float) -> None: def set_path_arc(self, new_value: float) -> Self:
self.path_arc = new_value self.path_arc = new_value
self.init_points() self.init_points()
return self
def set_start_and_end_attrs(self, start: Vect3 | Mobject, end: Vect3 | Mobject): def set_start_and_end_attrs(self, start: Vect3 | Mobject, end: Vect3 | Mobject):
# If either start or end are Mobjects, this # If either start or end are Mobjects, this
@ -541,7 +542,7 @@ class Line(TipableVMobject):
result[:len(point)] = point result[:len(point)] = point
return result return result
def put_start_and_end_on(self, start: Vect3, end: Vect3): def put_start_and_end_on(self, start: Vect3, end: Vect3) -> Self:
curr_start, curr_end = self.get_start_and_end() curr_start, curr_end = self.get_start_and_end()
if np.isclose(curr_start, curr_end).all(): if np.isclose(curr_start, curr_end).all():
# Handle null lines more gracefully # Handle null lines more gracefully
@ -569,7 +570,7 @@ class Line(TipableVMobject):
def get_slope(self) -> float: def get_slope(self) -> float:
return np.tan(self.get_angle()) return np.tan(self.get_angle())
def set_angle(self, angle: float, about_point: Vect3 | None = None): def set_angle(self, angle: float, about_point: Optional[Vect3] = None) -> Self:
if about_point is None: if about_point is None:
about_point = self.get_start() about_point = self.get_start()
self.rotate( self.rotate(
@ -695,13 +696,13 @@ class Arrow(Line):
end: Vect3, end: Vect3,
buff: float = 0, buff: float = 0,
path_arc: float = 0 path_arc: float = 0
): ) -> Self:
super().set_points_by_ends(start, end, buff, path_arc) super().set_points_by_ends(start, end, buff, path_arc)
self.insert_tip_anchor() self.insert_tip_anchor()
self.create_tip_with_stroke_width() self.create_tip_with_stroke_width()
return self return self
def insert_tip_anchor(self): def insert_tip_anchor(self) -> Self:
prev_end = self.get_end() prev_end = self.get_end()
arc_len = self.get_arc_length() arc_len = self.get_arc_length()
tip_len = self.get_stroke_width() * self.width_to_tip_len * self.tip_width_ratio tip_len = self.get_stroke_width() * self.width_to_tip_len * self.tip_width_ratio
@ -716,7 +717,7 @@ class Arrow(Line):
return self return self
@Mobject.affects_data @Mobject.affects_data
def create_tip_with_stroke_width(self): def create_tip_with_stroke_width(self) -> Self:
if self.get_num_points() < 3: if self.get_num_points() < 3:
return self return self
tip_width = self.tip_width_ratio * min( tip_width = self.tip_width_ratio * min(
@ -727,7 +728,7 @@ class Arrow(Line):
self.data['stroke_width'][-3:, 0] = tip_width * np.linspace(1, 0, 3) self.data['stroke_width'][-3:, 0] = tip_width * np.linspace(1, 0, 3)
return self return self
def reset_tip(self): def reset_tip(self) -> Self:
self.set_points_by_ends( self.set_points_by_ends(
self.get_start(), self.get_end(), self.get_start(), self.get_end(),
path_arc=self.path_arc path_arc=self.path_arc
@ -739,13 +740,13 @@ class Arrow(Line):
color: ManimColor | Iterable[ManimColor] | None = None, color: ManimColor | Iterable[ManimColor] | None = None,
width: float | Iterable[float] | None = None, width: float | Iterable[float] | None = None,
*args, **kwargs *args, **kwargs
): ) -> Self:
super().set_stroke(color=color, width=width, *args, **kwargs) super().set_stroke(color=color, width=width, *args, **kwargs)
if self.has_points(): if self.has_points():
self.reset_tip() self.reset_tip()
return self return self
def _handle_scale_side_effects(self, scale_factor: float): def _handle_scale_side_effects(self, scale_factor: float) -> Self:
if scale_factor != 1.0: if scale_factor != 1.0:
self.reset_tip() self.reset_tip()
return self return self
@ -787,7 +788,7 @@ class FillArrow(Line):
end: Vect3, end: Vect3,
buff: float = 0, buff: float = 0,
path_arc: float = 0 path_arc: float = 0
) -> None: ) -> Self:
# Find the right tip length and thickness # Find the right tip length and thickness
vect = end - start vect = end - start
length = max(get_norm(vect), 1e-8) length = max(get_norm(vect), 1e-8)
@ -848,8 +849,9 @@ class FillArrow(Line):
axis=rotate_vector(self.get_unit_vector(), -PI / 2), axis=rotate_vector(self.get_unit_vector(), -PI / 2),
) )
self.shift(start - self.get_start()) self.shift(start - self.get_start())
return self
def reset_points_around_ends(self): def reset_points_around_ends(self) -> Self:
self.set_points_by_ends( self.set_points_by_ends(
self.get_start().copy(), self.get_start().copy(),
self.get_end().copy(), self.get_end().copy(),
@ -864,21 +866,21 @@ class FillArrow(Line):
def get_end(self) -> Vect3: def get_end(self) -> Vect3:
return self.get_points()[self.tip_index] return self.get_points()[self.tip_index]
def put_start_and_end_on(self, start: Vect3, end: Vect3): def put_start_and_end_on(self, start: Vect3, end: Vect3) -> Self:
self.set_points_by_ends(start, end, buff=0, path_arc=self.path_arc) self.set_points_by_ends(start, end, buff=0, path_arc=self.path_arc)
return self return self
def scale(self, *args, **kwargs): def scale(self, *args, **kwargs) -> Self:
super().scale(*args, **kwargs) super().scale(*args, **kwargs)
self.reset_points_around_ends() self.reset_points_around_ends()
return self return self
def set_thickness(self, thickness: float): def set_thickness(self, thickness: float) -> Self:
self.thickness = thickness self.thickness = thickness
self.reset_points_around_ends() self.reset_points_around_ends()
return self return self
def set_path_arc(self, path_arc: float): def set_path_arc(self, path_arc: float) -> Self:
self.path_arc = path_arc self.path_arc = path_arc
self.reset_points_around_ends() self.reset_points_around_ends()
return self return self
@ -921,7 +923,7 @@ class Polygon(VMobject):
def get_vertices(self) -> Vect3Array: def get_vertices(self) -> Vect3Array:
return self.get_start_anchors() return self.get_start_anchors()
def round_corners(self, radius: float | None = None): def round_corners(self, radius: Optional[float] = None) -> Self:
if radius is None: if radius is None:
verts = self.get_vertices() verts = self.get_vertices()
min_edge_length = min( min_edge_length = min(

View file

@ -18,7 +18,7 @@ from manimlib.mobject.types.vectorized_mobject import VMobject
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Sequence from typing import Sequence, Self
import numpy.typing as npt import numpy.typing as npt
from manimlib.mobject.mobject import Mobject from manimlib.mobject.mobject import Mobject
from manimlib.typing import ManimColor, Vect3 from manimlib.typing import ManimColor, Vect3
@ -129,7 +129,7 @@ class Matrix(VMobject):
v_buff: float, v_buff: float,
h_buff: float, h_buff: float,
aligned_corner: Vect3, aligned_corner: Vect3,
): ) -> Self:
for i, row in enumerate(matrix): for i, row in enumerate(matrix):
for j, elem in enumerate(row): for j, elem in enumerate(row):
mob = matrix[i][j] mob = matrix[i][j]
@ -139,7 +139,7 @@ class Matrix(VMobject):
) )
return self return self
def add_brackets(self, v_buff: float, h_buff: float): def add_brackets(self, v_buff: float, h_buff: float) -> Self:
height = len(self.mob_matrix) height = len(self.mob_matrix)
brackets = Tex("".join(( brackets = Tex("".join((
R"\left[\begin{array}{c}", R"\left[\begin{array}{c}",
@ -168,13 +168,13 @@ class Matrix(VMobject):
for row in self.mob_matrix for row in self.mob_matrix
]) ])
def set_column_colors(self, *colors: ManimColor): def set_column_colors(self, *colors: ManimColor) -> Self:
columns = self.get_columns() columns = self.get_columns()
for color, column in zip(colors, columns): for color, column in zip(colors, columns):
column.set_color(color) column.set_color(color)
return self return self
def add_background_to_entries(self): def add_background_to_entries(self) -> Self:
for mob in self.get_entries(): for mob in self.get_entries():
mob.add_background_rectangle() mob.add_background_rectangle()
return self return self

File diff suppressed because it is too large Load diff

View file

@ -11,7 +11,7 @@ from manimlib.mobject.types.vectorized_mobject import VMobject
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import TypeVar from typing import TypeVar, Self
from manimlib.typing import ManimColor, Vect3 from manimlib.typing import ManimColor, Vect3
T = TypeVar("T", bound=VMobject) T = TypeVar("T", bound=VMobject)
@ -163,7 +163,7 @@ class DecimalNumber(VMobject):
def get_tex(self): def get_tex(self):
return self.num_string return self.num_string
def set_value(self, number: float | complex): def set_value(self, number: float | complex) -> Self:
move_to_point = self.get_edge_center(self.edge_to_fix) move_to_point = self.get_edge_center(self.edge_to_fix)
style = self.family_members_with_points()[0].get_style() style = self.family_members_with_points()[0].get_style()
self.set_submobjects_from_number(number) self.set_submobjects_from_number(number)
@ -171,14 +171,16 @@ class DecimalNumber(VMobject):
self.set_style(**style) self.set_style(**style)
return self return self
def _handle_scale_side_effects(self, scale_factor: float) -> None: def _handle_scale_side_effects(self, scale_factor: float) -> Self:
self.uniforms["font_size"] = scale_factor * self.uniforms["font_size"] self.uniforms["font_size"] = scale_factor * self.uniforms["font_size"]
return self
def get_value(self) -> float | complex: def get_value(self) -> float | complex:
return self.number return self.number
def increment_value(self, delta_t: float | complex = 1) -> None: def increment_value(self, delta_t: float | complex = 1) -> Self:
self.set_value(self.get_value() + delta_t) self.set_value(self.get_value() + delta_t)
return self
class Integer(DecimalNumber): class Integer(DecimalNumber):

View file

@ -14,7 +14,7 @@ from manimlib.utils.customization import get_customization
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Sequence from typing import Sequence, Self
from manimlib.mobject.mobject import Mobject from manimlib.mobject.mobject import Mobject
from manimlib.typing import ManimColor from manimlib.typing import ManimColor
@ -60,20 +60,20 @@ class BackgroundRectangle(SurroundingRectangle):
) )
self.original_fill_opacity = fill_opacity self.original_fill_opacity = fill_opacity
def pointwise_become_partial(self, mobject: Mobject, a: float, b: float): def pointwise_become_partial(self, mobject: Mobject, a: float, b: float) -> Self:
self.set_fill(opacity=b * self.original_fill_opacity) self.set_fill(opacity=b * self.original_fill_opacity)
return self return self
def set_style_data( def set_style(
self, self,
stroke_color: ManimColor | None = None, stroke_color: ManimColor | None = None,
stroke_width: float | None = None, stroke_width: float | None = None,
fill_color: ManimColor | None = None, fill_color: ManimColor | None = None,
fill_opacity: float | None = None, fill_opacity: float | None = None,
family: bool = True family: bool = True
): ) -> Self:
# Unchangeable style, except for fill_opacity # Unchangeable style, except for fill_opacity
VMobject.set_style_data( VMobject.set_style(
self, self,
stroke_color=BLACK, stroke_color=BLACK,
stroke_width=0, stroke_width=0,

View file

@ -166,7 +166,6 @@ class Cylinder(Surface):
self.scale(self.radius) self.scale(self.radius)
self.set_depth(self.height, stretch=True) self.set_depth(self.height, stretch=True)
self.apply_matrix(z_to_vector(self.axis)) self.apply_matrix(z_to_vector(self.axis))
return self
def uv_func(self, u: float, v: float) -> np.ndarray: def uv_func(self, u: float, v: float) -> np.ndarray:
return np.array([np.cos(u), np.sin(u), v]) return np.array([np.cos(u), np.sin(u), v])
@ -186,6 +185,7 @@ class Line3D(Cylinder):
height=get_norm(axis), height=get_norm(axis),
radius=width / 2, radius=width / 2,
axis=axis, axis=axis,
resolution=resolution,
**kwargs **kwargs
) )
self.shift((start + end) / 2) self.shift((start + end) / 2)
@ -376,16 +376,6 @@ class Dodecahedron(VGroup3D):
super().__init__(*pentagons, **style) super().__init__(*pentagons, **style)
# # Rotate those two pentagons by all the axis permuations to fill
# # out the dodecahedron
# Id = np.identity(3)
# for i in range(3):
# perm = [j % 3 for j in range(i, i + 3)]
# for b in [1, -1]:
# matrix = b * np.array([Id[0][perm], Id[1][perm], Id[2][perm]])
# self.add(pentagon1.copy().apply_matrix(matrix, about_point=ORIGIN))
# self.add(pentagon2.copy().apply_matrix(matrix, about_point=ORIGIN))
class Prismify(VGroup3D): class Prismify(VGroup3D):
def __init__(self, vmobject, depth=1.0, direction=IN, **kwargs): def __init__(self, vmobject, depth=1.0, direction=IN, **kwargs):

View file

@ -13,7 +13,7 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
import numpy.typing as npt import numpy.typing as npt
from typing import Sequence, Tuple from typing import Sequence, Tuple, Self
from manimlib.typing import ManimColor, Vect3, Vect3Array from manimlib.typing import ManimColor, Vect3, Vect3Array
@ -70,7 +70,7 @@ class DotCloud(PMobject):
v_buff_ratio: float = 1.0, v_buff_ratio: float = 1.0,
d_buff_ratio: float = 1.0, d_buff_ratio: float = 1.0,
height: float = DEFAULT_GRID_HEIGHT, height: float = DEFAULT_GRID_HEIGHT,
): ) -> Self:
n_points = n_rows * n_cols * n_layers n_points = n_rows * n_cols * n_layers
points = np.repeat(range(n_points), 3, axis=0).reshape((n_points, 3)) points = np.repeat(range(n_points), 3, axis=0).reshape((n_points, 3))
points[:, 0] = points[:, 0] % n_cols points[:, 0] = points[:, 0] % n_cols
@ -96,7 +96,7 @@ class DotCloud(PMobject):
return self return self
@Mobject.affects_data @Mobject.affects_data
def set_radii(self, radii: npt.ArrayLike): def set_radii(self, radii: npt.ArrayLike) -> Self:
n_points = self.get_num_points() n_points = self.get_num_points()
radii = np.array(radii).reshape((len(radii), 1)) radii = np.array(radii).reshape((len(radii), 1))
self.data["radius"][:] = resize_with_interpolation(radii, n_points) self.data["radius"][:] = resize_with_interpolation(radii, n_points)
@ -107,7 +107,7 @@ class DotCloud(PMobject):
return self.data["radius"] return self.data["radius"]
@Mobject.affects_data @Mobject.affects_data
def set_radius(self, radius: float): def set_radius(self, radius: float) -> Self:
data = self.data if self.get_num_points() > 0 else self._data_defaults data = self.data if self.get_num_points() > 0 else self._data_defaults
data["radius"][:] = radius data["radius"][:] = radius
self.refresh_bounding_box() self.refresh_bounding_box()
@ -116,13 +116,14 @@ class DotCloud(PMobject):
def get_radius(self) -> float: def get_radius(self) -> float:
return self.get_radii().max() return self.get_radii().max()
def set_glow_factor(self, glow_factor: float) -> None: def set_glow_factor(self, glow_factor: float) -> Self:
self.uniforms["glow_factor"] = glow_factor self.uniforms["glow_factor"] = glow_factor
return self
def get_glow_factor(self) -> float: def get_glow_factor(self) -> float:
return self.uniforms["glow_factor"] return self.uniforms["glow_factor"]
def compute_bounding_box(self) -> np.ndarray: def compute_bounding_box(self) -> Vect3Array:
bb = super().compute_bounding_box() bb = super().compute_bounding_box()
radius = self.get_radius() radius = self.get_radius()
bb[0] += np.full((3,), -radius) bb[0] += np.full((3,), -radius)
@ -134,7 +135,7 @@ class DotCloud(PMobject):
scale_factor: float | npt.ArrayLike, scale_factor: float | npt.ArrayLike,
scale_radii: bool = True, scale_radii: bool = True,
**kwargs **kwargs
): ) -> Self:
super().scale(scale_factor, **kwargs) super().scale(scale_factor, **kwargs)
if scale_radii: if scale_radii:
self.set_radii(scale_factor * self.get_radii()) self.set_radii(scale_factor * self.get_radii())
@ -145,7 +146,7 @@ class DotCloud(PMobject):
reflectiveness: float = 0.5, reflectiveness: float = 0.5,
gloss: float = 0.1, gloss: float = 0.1,
shadow: float = 0.2 shadow: float = 0.2
): ) -> Self:
self.set_shading(reflectiveness, gloss, shadow) self.set_shading(reflectiveness, gloss, shadow)
self.apply_depth_test() self.apply_depth_test()
return self return self

View file

@ -10,7 +10,7 @@ from manimlib.utils.iterables import resize_with_interpolation
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Callable from typing import Callable, Self
from manimlib.typing import ManimColor, Vect3, Vect3Array, Vect4Array from manimlib.typing import ManimColor, Vect3, Vect3Array, Vect4Array
@ -28,7 +28,7 @@ class PMobject(Mobject):
rgbas: Vect4Array | None = None, rgbas: Vect4Array | None = None,
color: ManimColor | None = None, color: ManimColor | None = None,
opacity: float | None = None opacity: float | None = None
): ) -> Self:
""" """
points must be a Nx3 numpy array, as must rgbas if it is not None points must be a Nx3 numpy array, as must rgbas if it is not None
""" """
@ -46,13 +46,13 @@ class PMobject(Mobject):
self.data["rgba"][-len(rgbas):] = rgbas self.data["rgba"][-len(rgbas):] = rgbas
return self return self
def add_point(self, point: Vect3, rgba=None, color=None, opacity=None): def add_point(self, point: Vect3, rgba=None, color=None, opacity=None) -> Self:
rgbas = None if rgba is None else [rgba] rgbas = None if rgba is None else [rgba]
self.add_points([point], rgbas, color, opacity) self.add_points([point], rgbas, color, opacity)
return self return self
@Mobject.affects_data @Mobject.affects_data
def set_color_by_gradient(self, *colors: ManimColor): def set_color_by_gradient(self, *colors: ManimColor) -> Self:
self.data["rgba"][:] = np.array(list(map( self.data["rgba"][:] = np.array(list(map(
color_to_rgba, color_to_rgba,
color_gradient(colors, self.get_num_points()) color_gradient(colors, self.get_num_points())
@ -60,20 +60,20 @@ class PMobject(Mobject):
return self return self
@Mobject.affects_data @Mobject.affects_data
def match_colors(self, pmobject: PMobject): def match_colors(self, pmobject: PMobject) -> Self:
self.data["rgba"][:] = resize_with_interpolation( self.data["rgba"][:] = resize_with_interpolation(
pmobject.data["rgba"], self.get_num_points() pmobject.data["rgba"], self.get_num_points()
) )
return self return self
@Mobject.affects_data @Mobject.affects_data
def filter_out(self, condition: Callable[[np.ndarray], bool]): def filter_out(self, condition: Callable[[np.ndarray], bool]) -> Self:
for mob in self.family_members_with_points(): for mob in self.family_members_with_points():
mob.data = mob.data[~np.apply_along_axis(condition, 1, mob.get_points())] mob.data = mob.data[~np.apply_along_axis(condition, 1, mob.get_points())]
return self return self
@Mobject.affects_data @Mobject.affects_data
def sort_points(self, function: Callable[[Vect3], None] = lambda p: p[0]): def sort_points(self, function: Callable[[Vect3], None] = lambda p: p[0]) -> Self:
""" """
function is any map from R^3 to R function is any map from R^3 to R
""" """
@ -85,7 +85,7 @@ class PMobject(Mobject):
return self return self
@Mobject.affects_data @Mobject.affects_data
def ingest_submobjects(self): def ingest_submobjects(self) -> Self:
self.data = np.vstack([ self.data = np.vstack([
sm.data for sm in self.get_family() sm.data for sm in self.get_family()
]) ])
@ -96,7 +96,7 @@ class PMobject(Mobject):
return self.get_points()[int(index)] return self.get_points()[int(index)]
@Mobject.affects_data @Mobject.affects_data
def pointwise_become_partial(self, pmobject: PMobject, a: float, b: float): def pointwise_become_partial(self, pmobject: PMobject, a: float, b: float) -> Self:
lower_index = int(a * pmobject.get_num_points()) lower_index = int(a * pmobject.get_num_points())
upper_index = int(b * pmobject.get_num_points()) upper_index = int(b * pmobject.get_num_points())
self.data = pmobject.data[lower_index:upper_index].copy() self.data = pmobject.data[lower_index:upper_index].copy()

View file

@ -17,7 +17,7 @@ from manimlib.utils.space_ops import cross
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Callable, Iterable, Sequence, Tuple from typing import Callable, Iterable, Sequence, Tuple, Self
from manimlib.camera.camera import Camera from manimlib.camera.camera import Camera
from manimlib.typing import ManimColor, Vect3, Vect3Array from manimlib.typing import ManimColor, Vect3, Vect3Array
@ -100,18 +100,19 @@ class Surface(Mobject):
(dv_points - points) / self.epsilon, (dv_points - points) / self.epsilon,
), 1) ), 1)
def apply_points_function(self, *args, **kwargs): def apply_points_function(self, *args, **kwargs) -> Self:
super().apply_points_function(*args, **kwargs) super().apply_points_function(*args, **kwargs)
self.get_unit_normals() self.get_unit_normals()
return self
def compute_triangle_indices(self): def compute_triangle_indices(self) -> np.ndarray:
# TODO, if there is an event which changes # TODO, if there is an event which changes
# the resolution of the surface, make sure # the resolution of the surface, make sure
# this is called. # this is called.
nu, nv = self.resolution nu, nv = self.resolution
if nu == 0 or nv == 0: if nu == 0 or nv == 0:
self.triangle_indices = np.zeros(0, dtype=int) self.triangle_indices = np.zeros(0, dtype=int)
return return self.triangle_indices
index_grid = np.arange(nu * nv).reshape((nu, nv)) index_grid = np.arange(nu * nv).reshape((nu, nv))
indices = np.zeros(6 * (nu - 1) * (nv - 1), dtype=int) indices = np.zeros(6 * (nu - 1) * (nv - 1), dtype=int)
indices[0::6] = index_grid[:-1, :-1].flatten() # Top left indices[0::6] = index_grid[:-1, :-1].flatten() # Top left
@ -121,6 +122,7 @@ class Surface(Mobject):
indices[4::6] = index_grid[+1:, :-1].flatten() # Bottom left indices[4::6] = index_grid[+1:, :-1].flatten() # Bottom left
indices[5::6] = index_grid[+1:, +1:].flatten() # Bottom right indices[5::6] = index_grid[+1:, +1:].flatten() # Bottom right
self.triangle_indices = indices self.triangle_indices = indices
return self.triangle_indices
def get_triangle_indices(self) -> np.ndarray: def get_triangle_indices(self) -> np.ndarray:
return self.triangle_indices return self.triangle_indices
@ -154,7 +156,7 @@ class Surface(Mobject):
a: float, a: float,
b: float, b: float,
axis: int | None = None axis: int | None = None
): ) -> Self:
assert(isinstance(smobject, Surface)) assert(isinstance(smobject, Surface))
if axis is None: if axis is None:
axis = self.prefered_creation_axis axis = self.prefered_creation_axis
@ -211,7 +213,7 @@ class Surface(Mobject):
return points.reshape((nu * nv, *resolution[2:])) return points.reshape((nu * nv, *resolution[2:]))
@Mobject.affects_data @Mobject.affects_data
def sort_faces_back_to_front(self, vect: Vect3 = OUT): def sort_faces_back_to_front(self, vect: Vect3 = OUT) -> Self:
tri_is = self.triangle_indices tri_is = self.triangle_indices
points = self.get_points() points = self.get_points()
@ -221,24 +223,25 @@ class Surface(Mobject):
tri_is[k::3] = tri_is[k::3][indices] tri_is[k::3] = tri_is[k::3][indices]
return self return self
def always_sort_to_camera(self, camera: Camera): def always_sort_to_camera(self, camera: Camera) -> Self:
def updater(surface: Surface): def updater(surface: Surface):
vect = camera.get_location() - surface.get_center() vect = camera.get_location() - surface.get_center()
surface.sort_faces_back_to_front(vect) surface.sort_faces_back_to_front(vect)
self.add_updater(updater) self.add_updater(updater)
return self
def set_clip_plane( def set_clip_plane(
self, self,
vect: Vect3 | None = None, vect: Vect3 | None = None,
threshold: float | None = None threshold: float | None = None
): ) -> Self:
if vect is not None: if vect is not None:
self.uniforms["clip_plane"][:3] = vect self.uniforms["clip_plane"][:3] = vect
if threshold is not None: if threshold is not None:
self.uniforms["clip_plane"][3] = threshold self.uniforms["clip_plane"][3] = threshold
return self return self
def deactivate_clip_plane(self): def deactivate_clip_plane(self) -> Self:
self.uniforms["clip_plane"][:] = 0 self.uniforms["clip_plane"][:] = 0
return self return self
@ -335,7 +338,7 @@ class TexturedSurface(Surface):
self.uniforms["num_textures"] = self.num_textures self.uniforms["num_textures"] = self.num_textures
@Mobject.affects_data @Mobject.affects_data
def set_opacity(self, opacity: float | Iterable[float]): def set_opacity(self, opacity: float | Iterable[float]) -> Self:
op_arr = np.array(listify(opacity)) op_arr = np.array(listify(opacity))
self.data["opacity"][:, 0] = resize_with_interpolation(op_arr, len(self.data)) self.data["opacity"][:, 0] = resize_with_interpolation(op_arr, len(self.data))
return self return self
@ -345,7 +348,7 @@ class TexturedSurface(Surface):
color: ManimColor | Iterable[ManimColor] | None, color: ManimColor | Iterable[ManimColor] | None,
opacity: float | Iterable[float] | None = None, opacity: float | Iterable[float] | None = None,
recurse: bool = True recurse: bool = True
): ) -> Self:
if opacity is not None: if opacity is not None:
self.set_opacity(opacity) self.set_opacity(opacity)
return self return self
@ -356,7 +359,7 @@ class TexturedSurface(Surface):
a: float, a: float,
b: float, b: float,
axis: int = 1 axis: int = 1
): ) -> Self:
super().pointwise_become_partial(tsmobject, a, b, axis) super().pointwise_become_partial(tsmobject, a, b, axis)
im_coords = self.data["im_coords"] im_coords = self.data["im_coords"]
im_coords[:] = tsmobject.data["im_coords"] im_coords[:] = tsmobject.data["im_coords"]

View file

@ -45,7 +45,7 @@ from manimlib.shader_wrapper import FillShaderWrapper
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Callable, Iterable, Tuple from typing import Callable, Iterable, Tuple, Any, Self
from manimlib.typing import ManimColor, Vect3, Vect4, Vect3Array, Vect4Array from manimlib.typing import ManimColor, Vect3, Vect4, Vect3Array, Vect4Array
from moderngl.context import Context from moderngl.context import Context
@ -128,29 +128,10 @@ class VMobject(Mobject):
self.uniforms["joint_type"] = JOINT_TYPE_MAP[self.joint_type] self.uniforms["joint_type"] = JOINT_TYPE_MAP[self.joint_type]
self.uniforms["flat_stroke"] = float(self.flat_stroke) self.uniforms["flat_stroke"] = float(self.flat_stroke)
# These are here just to make type checkers happy def add(self, *vmobjects: VMobject) -> Self:
def get_family(self, recurse: bool = True) -> list[VMobject]:
return super().get_family(recurse)
def family_members_with_points(self) -> list[VMobject]:
return super().family_members_with_points()
def replicate(self, n: int) -> VGroup:
return super().replicate(n)
def get_grid(self, *args, **kwargs) -> VGroup:
return super().get_grid(*args, **kwargs)
def __getitem__(self, value: int | slice) -> VMobject:
return super().__getitem__(value)
def __iter__(self) -> Iterable[VMobject]:
return super().__iter__()
def add(self, *vmobjects: VMobject):
if not all((isinstance(m, VMobject) for m in vmobjects)): if not all((isinstance(m, VMobject) for m in vmobjects)):
raise Exception("All submobjects must be of type VMobject") raise Exception("All submobjects must be of type VMobject")
super().add(*vmobjects) return super().add(*vmobjects)
# Colors # Colors
def init_colors(self): def init_colors(self):
@ -175,7 +156,7 @@ class VMobject(Mobject):
rgba_array: Vect4Array, rgba_array: Vect4Array,
name: str | None = None, name: str | None = None,
recurse: bool = False recurse: bool = False
): ) -> Self:
if name is None: if name is None:
names = ["fill_rgba", "stroke_rgba"] names = ["fill_rgba", "stroke_rgba"]
else: else:
@ -191,7 +172,7 @@ class VMobject(Mobject):
opacity: float | Iterable[float] | None = None, opacity: float | Iterable[float] | None = None,
border_width: float | None = None, border_width: float | None = None,
recurse: bool = True recurse: bool = True
): ) -> Self:
self.set_rgba_array_by_color(color, opacity, 'fill_rgba', recurse) self.set_rgba_array_by_color(color, opacity, 'fill_rgba', recurse)
if border_width is not None: if border_width is not None:
for mob in self.get_family(recurse): for mob in self.get_family(recurse):
@ -205,7 +186,7 @@ class VMobject(Mobject):
opacity: float | Iterable[float] | None = None, opacity: float | Iterable[float] | None = None,
background: bool | None = None, background: bool | None = None,
recurse: bool = True recurse: bool = True
): ) -> Self:
self.set_rgba_array_by_color(color, opacity, 'stroke_rgba', recurse) self.set_rgba_array_by_color(color, opacity, 'stroke_rgba', recurse)
if width is not None: if width is not None:
@ -228,7 +209,7 @@ class VMobject(Mobject):
color: ManimColor | Iterable[ManimColor] = BLACK, color: ManimColor | Iterable[ManimColor] = BLACK,
width: float | Iterable[float] = 3, width: float | Iterable[float] = 3,
background: bool = True background: bool = True
): ) -> Self:
self.set_stroke(color, width, background=background) self.set_stroke(color, width, background=background)
return self return self
@ -245,7 +226,7 @@ class VMobject(Mobject):
stroke_background: bool = True, stroke_background: bool = True,
shading: Tuple[float, float, float] | None = None, shading: Tuple[float, float, float] | None = None,
recurse: bool = True recurse: bool = True
): ) -> Self:
for mob in self.get_family(recurse): for mob in self.get_family(recurse):
if fill_rgba is not None: if fill_rgba is not None:
mob.data['fill_rgba'][:] = resize_with_interpolation(fill_rgba, len(mob.data['fill_rgba'])) mob.data['fill_rgba'][:] = resize_with_interpolation(fill_rgba, len(mob.data['fill_rgba']))
@ -276,7 +257,7 @@ class VMobject(Mobject):
mob.set_shading(*shading, recurse=False) mob.set_shading(*shading, recurse=False)
return self return self
def get_style(self): def get_style(self) -> dict[str, Any]:
data = self.data if self.get_num_points() > 0 else self._data_defaults data = self.data if self.get_num_points() > 0 else self._data_defaults
return { return {
"fill_rgba": data['fill_rgba'].copy(), "fill_rgba": data['fill_rgba'].copy(),
@ -286,7 +267,7 @@ class VMobject(Mobject):
"shading": self.get_shading(), "shading": self.get_shading(),
} }
def match_style(self, vmobject: VMobject, recurse: bool = True): def match_style(self, vmobject: VMobject, recurse: bool = True) -> Self:
self.set_style(**vmobject.get_style(), recurse=False) self.set_style(**vmobject.get_style(), recurse=False)
if recurse: if recurse:
# Does its best to match up submobject lists, and # Does its best to match up submobject lists, and
@ -305,7 +286,7 @@ class VMobject(Mobject):
color: ManimColor | Iterable[ManimColor] | None, color: ManimColor | Iterable[ManimColor] | None,
opacity: float | Iterable[float] | None = None, opacity: float | Iterable[float] | None = None,
recurse: bool = True recurse: bool = True
): ) -> Self:
self.set_fill(color, opacity=opacity, recurse=recurse) self.set_fill(color, opacity=opacity, recurse=recurse)
self.set_stroke(color, opacity=opacity, recurse=recurse) self.set_stroke(color, opacity=opacity, recurse=recurse)
return self return self
@ -314,16 +295,16 @@ class VMobject(Mobject):
self, self,
opacity: float | Iterable[float] | None, opacity: float | Iterable[float] | None,
recurse: bool = True recurse: bool = True
): ) -> Self:
self.set_fill(opacity=opacity, recurse=recurse) self.set_fill(opacity=opacity, recurse=recurse)
self.set_stroke(opacity=opacity, recurse=recurse) self.set_stroke(opacity=opacity, recurse=recurse)
return self return self
def set_anti_alias_width(self, anti_alias_width: float, recurse: bool = True): def set_anti_alias_width(self, anti_alias_width: float, recurse: bool = True) -> Self:
self.set_uniform(recurse, anti_alias_width=anti_alias_width) self.set_uniform(recurse, anti_alias_width=anti_alias_width)
return self return self
def fade(self, darkness: float = 0.5, recurse: bool = True): def fade(self, darkness: float = 0.5, recurse: bool = True) -> Self:
mobs = self.get_family() if recurse else [self] mobs = self.get_family() if recurse else [self]
for mob in mobs: for mob in mobs:
factor = 1.0 - darkness factor = 1.0 - darkness
@ -407,7 +388,7 @@ class VMobject(Mobject):
return self.get_fill_opacity() return self.get_fill_opacity()
return self.get_stroke_opacity() return self.get_stroke_opacity()
def set_flat_stroke(self, flat_stroke: bool = True, recurse: bool = True): def set_flat_stroke(self, flat_stroke: bool = True, recurse: bool = True) -> Self:
for mob in self.get_family(recurse): for mob in self.get_family(recurse):
mob.uniforms["flat_stroke"] = float(flat_stroke) mob.uniforms["flat_stroke"] = float(flat_stroke)
return self return self
@ -415,7 +396,7 @@ class VMobject(Mobject):
def get_flat_stroke(self) -> bool: def get_flat_stroke(self) -> bool:
return self.uniforms["flat_stroke"] == 1.0 return self.uniforms["flat_stroke"] == 1.0
def set_joint_type(self, joint_type: str, recurse: bool = True): def set_joint_type(self, joint_type: str, recurse: bool = True) -> Self:
for mob in self.get_family(recurse): for mob in self.get_family(recurse):
mob.uniforms["joint_type"] = JOINT_TYPE_MAP[joint_type] mob.uniforms["joint_type"] = JOINT_TYPE_MAP[joint_type]
return self return self
@ -428,7 +409,7 @@ class VMobject(Mobject):
anti_alias_width: float = 0, anti_alias_width: float = 0,
fill_border_width: float = 0, fill_border_width: float = 0,
recurse: bool=True recurse: bool=True
): ) -> Self:
super().apply_depth_test(recurse) super().apply_depth_test(recurse)
self.set_anti_alias_width(anti_alias_width) self.set_anti_alias_width(anti_alias_width)
self.set_fill(border_width=fill_border_width) self.set_fill(border_width=fill_border_width)
@ -439,14 +420,14 @@ class VMobject(Mobject):
anti_alias_width: float = 1.0, anti_alias_width: float = 1.0,
fill_border_width: float = 0.5, fill_border_width: float = 0.5,
recurse: bool=True recurse: bool=True
): ) -> Self:
super().apply_depth_test(recurse) super().apply_depth_test(recurse)
self.set_anti_alias_width(anti_alias_width) self.set_anti_alias_width(anti_alias_width)
self.set_fill(border_width=fill_border_width) self.set_fill(border_width=fill_border_width)
return self return self
@Mobject.affects_family_data @Mobject.affects_family_data
def use_winding_fill(self, value: bool = True, recurse: bool = True): def use_winding_fill(self, value: bool = True, recurse: bool = True) -> Self:
for submob in self.get_family(recurse): for submob in self.get_family(recurse):
submob._use_winding_fill = value submob._use_winding_fill = value
if not value and submob.has_points(): if not value and submob.has_points():
@ -458,7 +439,7 @@ class VMobject(Mobject):
self, self,
anchors: Vect3Array, anchors: Vect3Array,
handles: Vect3Array, handles: Vect3Array,
): ) -> Self:
assert(len(anchors) == len(handles) + 1) assert(len(anchors) == len(handles) + 1)
points = resize_array(self.get_points(), 2 * len(anchors) - 1) points = resize_array(self.get_points(), 2 * len(anchors) - 1)
points[0::2] = anchors points[0::2] = anchors
@ -466,7 +447,7 @@ class VMobject(Mobject):
self.set_points(points) self.set_points(points)
return self return self
def start_new_path(self, point: Vect3): def start_new_path(self, point: Vect3) -> Self:
# Path ends are signaled by a handle point sitting directly # Path ends are signaled by a handle point sitting directly
# on top of the previous anchor # on top of the previous anchor
if self.has_points(): if self.has_points():
@ -481,7 +462,7 @@ class VMobject(Mobject):
handle1: Vect3, handle1: Vect3,
handle2: Vect3, handle2: Vect3,
anchor2: Vect3 anchor2: Vect3
): ) -> Self:
self.start_new_path(anchor1) self.start_new_path(anchor1)
self.add_cubic_bezier_curve_to(handle1, handle2, anchor2) self.add_cubic_bezier_curve_to(handle1, handle2, anchor2)
return self return self
@ -491,7 +472,7 @@ class VMobject(Mobject):
handle1: Vect3, handle1: Vect3,
handle2: Vect3, handle2: Vect3,
anchor: Vect3, anchor: Vect3,
): ) -> Self:
""" """
Add cubic bezier curve to the path. Add cubic bezier curve to the path.
""" """
@ -513,7 +494,7 @@ class VMobject(Mobject):
self.append_points(quad_approx[1:]) self.append_points(quad_approx[1:])
return self return self
def add_quadratic_bezier_curve_to(self, handle: Vect3, anchor: Vect3): def add_quadratic_bezier_curve_to(self, handle: Vect3, anchor: Vect3) -> Self:
self.throw_error_if_no_points() self.throw_error_if_no_points()
last_point = self.get_last_point() last_point = self.get_last_point()
if self.consider_points_equal(handle, last_point): if self.consider_points_equal(handle, last_point):
@ -522,14 +503,14 @@ class VMobject(Mobject):
self.append_points([handle, anchor]) self.append_points([handle, anchor])
return self return self
def add_line_to(self, point: Vect3): def add_line_to(self, point: Vect3) -> Self:
self.throw_error_if_no_points() self.throw_error_if_no_points()
last_point = self.get_last_point() last_point = self.get_last_point()
alphas = np.linspace(0, 1, 5 if self.long_lines else 3) alphas = np.linspace(0, 1, 5 if self.long_lines else 3)
self.append_points(outer_interpolate(last_point, point, alphas[1:])) self.append_points(outer_interpolate(last_point, point, alphas[1:]))
return self return self
def add_smooth_curve_to(self, point: Vect3): def add_smooth_curve_to(self, point: Vect3) -> Self:
if self.has_new_path_started(): if self.has_new_path_started():
self.add_line_to(point) self.add_line_to(point)
else: else:
@ -538,7 +519,7 @@ class VMobject(Mobject):
self.add_quadratic_bezier_curve_to(new_handle, point) self.add_quadratic_bezier_curve_to(new_handle, point)
return self return self
def add_smooth_cubic_curve_to(self, handle: Vect3, point: Vect3): def add_smooth_cubic_curve_to(self, handle: Vect3, point: Vect3) -> Self:
self.throw_error_if_no_points() self.throw_error_if_no_points()
if self.get_num_points() == 1: if self.get_num_points() == 1:
new_handle = handle new_handle = handle
@ -559,7 +540,7 @@ class VMobject(Mobject):
points = self.get_points() points = self.get_points()
return 2 * points[-1] - points[-2] return 2 * points[-1] - points[-2]
def close_path(self, smooth: bool = False): def close_path(self, smooth: bool = False) -> Self:
if self.is_closed(): if self.is_closed():
return self return self
last_path_start = self.get_subpaths()[-1][0] last_path_start = self.get_subpaths()[-1][0]
@ -577,7 +558,7 @@ class VMobject(Mobject):
self, self,
tuple_to_subdivisions: Callable, tuple_to_subdivisions: Callable,
recurse: bool = True recurse: bool = True
): ) -> Self:
for vmob in self.get_family(recurse): for vmob in self.get_family(recurse):
if not vmob.has_points(): if not vmob.has_points():
continue continue
@ -599,7 +580,7 @@ class VMobject(Mobject):
self, self,
angle_threshold: float = 30 * DEGREES, angle_threshold: float = 30 * DEGREES,
recurse: bool = True recurse: bool = True
): ) -> Self:
def tuple_to_subdivisions(b0, b1, b2): def tuple_to_subdivisions(b0, b1, b2):
angle = angle_between_vectors(b1 - b0, b2 - b1) angle = angle_between_vectors(b1 - b0, b2 - b1)
return int(angle / angle_threshold) return int(angle / angle_threshold)
@ -607,7 +588,7 @@ class VMobject(Mobject):
self.subdivide_curves_by_condition(tuple_to_subdivisions, recurse) self.subdivide_curves_by_condition(tuple_to_subdivisions, recurse)
return self return self
def subdivide_intersections(self, recurse: bool = True, n_subdivisions: int = 1): def subdivide_intersections(self, recurse: bool = True, n_subdivisions: int = 1) -> Self:
path = self.get_anchors() path = self.get_anchors()
def tuple_to_subdivisions(b0, b1, b2): def tuple_to_subdivisions(b0, b1, b2):
if line_intersects_path(b0, b1, path): if line_intersects_path(b0, b1, path):
@ -617,12 +598,12 @@ class VMobject(Mobject):
self.subdivide_curves_by_condition(tuple_to_subdivisions, recurse) self.subdivide_curves_by_condition(tuple_to_subdivisions, recurse)
return self return self
def add_points_as_corners(self, points: Iterable[Vect3]): def add_points_as_corners(self, points: Iterable[Vect3]) -> Self:
for point in points: for point in points:
self.add_line_to(point) self.add_line_to(point)
return points return self
def set_points_as_corners(self, points: Iterable[Vect3]): def set_points_as_corners(self, points: Iterable[Vect3]) -> Self:
anchors = np.array(points) anchors = np.array(points)
handles = 0.5 * (anchors[:-1] + anchors[1:]) handles = 0.5 * (anchors[:-1] + anchors[1:])
self.set_anchors_and_handles(anchors, handles) self.set_anchors_and_handles(anchors, handles)
@ -632,7 +613,7 @@ class VMobject(Mobject):
self, self,
points: Iterable[Vect3], points: Iterable[Vect3],
approx: bool = True approx: bool = True
): ) -> Self:
self.set_points_as_corners(points) self.set_points_as_corners(points)
self.make_smooth(approx=approx) self.make_smooth(approx=approx)
return self return self
@ -641,7 +622,7 @@ class VMobject(Mobject):
dots = self.get_joint_products()[::2, 3] dots = self.get_joint_products()[::2, 3]
return bool((dots > 1 - 1e-3).all()) return bool((dots > 1 - 1e-3).all())
def change_anchor_mode(self, mode: str): def change_anchor_mode(self, mode: str) -> Self:
assert(mode in ("jagged", "approx_smooth", "true_smooth")) assert(mode in ("jagged", "approx_smooth", "true_smooth"))
subpaths = self.get_subpaths() subpaths = self.get_subpaths()
self.clear_points() self.clear_points()
@ -664,7 +645,7 @@ class VMobject(Mobject):
self.add_subpath(new_subpath) self.add_subpath(new_subpath)
return self return self
def make_smooth(self, approx=False, recurse=True): def make_smooth(self, approx=False, recurse=True) -> Self:
""" """
Edits the path so as to pass smoothly through all Edits the path so as to pass smoothly through all
the current anchor points. the current anchor points.
@ -679,15 +660,16 @@ class VMobject(Mobject):
submob.change_anchor_mode(mode) submob.change_anchor_mode(mode)
return self return self
def make_approximately_smooth(self, recurse=True): def make_approximately_smooth(self, recurse=True) -> Self:
self.make_smooth(approx=True, recurse=recurse) self.make_smooth(approx=True, recurse=recurse)
return self
def make_jagged(self, recurse=True): def make_jagged(self, recurse=True) -> Self:
for submob in self.get_family(recurse): for submob in self.get_family(recurse):
submob.change_anchor_mode("jagged") submob.change_anchor_mode("jagged")
return self return self
def add_subpath(self, points: Vect3Array): def add_subpath(self, points: Vect3Array) -> Self:
assert(len(points) % 2 == 1 or len(points) == 0) assert(len(points) % 2 == 1 or len(points) == 0)
if not self.has_points(): if not self.has_points():
self.set_points(points) self.set_points(points)
@ -697,7 +679,7 @@ class VMobject(Mobject):
self.append_points(points[1:]) self.append_points(points[1:])
return self return self
def append_vectorized_mobject(self, vmobject: VMobject): def append_vectorized_mobject(self, vmobject: VMobject) -> Self:
self.add_subpath(vmobject.get_points()) self.add_subpath(vmobject.get_points())
n = vmobject.get_num_points() n = vmobject.get_num_points()
self.data[-n:] = vmobject.data self.data[-n:] = vmobject.data
@ -715,7 +697,7 @@ class VMobject(Mobject):
def get_bezier_tuples(self) -> Iterable[Vect3Array]: def get_bezier_tuples(self) -> Iterable[Vect3Array]:
return self.get_bezier_tuples_from_points(self.get_points()) return self.get_bezier_tuples_from_points(self.get_points())
def get_subpath_end_indices_from_points(self, points: Vect3Array): def get_subpath_end_indices_from_points(self, points: Vect3Array) -> np.ndarray:
atol = self.tolerance_for_point_equality atol = self.tolerance_for_point_equality
a0, h, a1 = points[0:-1:2], points[1::2], points[2::2] a0, h, a1 = points[0:-1:2], points[1::2], points[2::2]
# An anchor point is considered the end of a path # An anchor point is considered the end of a path
@ -731,7 +713,7 @@ class VMobject(Mobject):
is_end[:-1] = is_end[:-1] & ~is_end[1:] is_end[:-1] = is_end[:-1] & ~is_end[1:]
return np.array([2 * n for n, end in enumerate(is_end) if end]) return np.array([2 * n for n, end in enumerate(is_end) if end])
def get_subpath_end_indices(self): def get_subpath_end_indices(self) -> np.ndarray:
return self.get_subpath_end_indices_from_points(self.get_points()) return self.get_subpath_end_indices_from_points(self.get_points())
def get_subpaths_from_points(self, points: Vect3Array) -> list[Vect3Array]: def get_subpaths_from_points(self, points: Vect3Array) -> list[Vect3Array]:
@ -860,7 +842,7 @@ class VMobject(Mobject):
self.data["unit_normal"][:] = normal self.data["unit_normal"][:] = normal
return normal return normal
def refresh_unit_normal(self): def refresh_unit_normal(self) -> Self:
self.get_unit_normal() self.get_unit_normal()
return self return self
@ -870,20 +852,20 @@ class VMobject(Mobject):
axis: Vect3 = OUT, axis: Vect3 = OUT,
about_point: Vect3 | None = None, about_point: Vect3 | None = None,
**kwargs **kwargs
): ) -> Self:
super().rotate(angle, axis, about_point, **kwargs) super().rotate(angle, axis, about_point, **kwargs)
for mob in self.get_family(): for mob in self.get_family():
mob.refresh_unit_normal() mob.refresh_unit_normal()
return self return self
def ensure_positive_orientation(self, recurse=True): def ensure_positive_orientation(self, recurse=True) -> Self:
for mob in self.get_family(recurse): for mob in self.get_family(recurse):
if mob.get_unit_normal()[2] < 0: if mob.get_unit_normal()[2] < 0:
mob.reverse_points() mob.reverse_points()
return self return self
# Alignment # Alignment
def align_points(self, vmobject: VMobject): def align_points(self, vmobject: VMobject) -> Self:
winding = self._use_winding_fill and vmobject._use_winding_fill winding = self._use_winding_fill and vmobject._use_winding_fill
self.use_winding_fill(winding) self.use_winding_fill(winding)
vmobject.use_winding_fill(winding) vmobject.use_winding_fill(winding)
@ -940,7 +922,7 @@ class VMobject(Mobject):
mob.get_joint_products() mob.get_joint_products()
return self return self
def invisible_copy(self): def invisible_copy(self) -> Self:
result = self.copy() result = self.copy()
if not result.has_fill() or result.get_num_points() == 0: if not result.has_fill() or result.get_num_points() == 0:
return result return result
@ -948,14 +930,14 @@ class VMobject(Mobject):
result.set_opacity(0) result.set_opacity(0)
return result return result
def insert_n_curves(self, n: int, recurse: bool = True): def insert_n_curves(self, n: int, recurse: bool = True) -> Self:
for mob in self.get_family(recurse): for mob in self.get_family(recurse):
if mob.get_num_curves() > 0: if mob.get_num_curves() > 0:
new_points = mob.insert_n_curves_to_point_list(n, mob.get_points()) new_points = mob.insert_n_curves_to_point_list(n, mob.get_points())
mob.set_points(new_points) mob.set_points(new_points)
return self return self
def insert_n_curves_to_point_list(self, n: int, points: Vect3Array): def insert_n_curves_to_point_list(self, n: int, points: Vect3Array) -> Vect3Array:
if len(points) == 1: if len(points) == 1:
return np.repeat(points, 2 * n + 1, 0) return np.repeat(points, 2 * n + 1, 0)
@ -988,7 +970,7 @@ class VMobject(Mobject):
mobject2: VMobject, mobject2: VMobject,
alpha: float, alpha: float,
*args, **kwargs *args, **kwargs
): ) -> Self:
super().interpolate(mobject1, mobject2, alpha, *args, **kwargs) super().interpolate(mobject1, mobject2, alpha, *args, **kwargs)
if self.has_fill() and not self._use_winding_fill: if self.has_fill() and not self._use_winding_fill:
tri1 = mobject1.get_triangulation() tri1 = mobject1.get_triangulation()
@ -997,7 +979,7 @@ class VMobject(Mobject):
self.refresh_triangulation() self.refresh_triangulation()
return self return self
def pointwise_become_partial(self, vmobject: VMobject, a: float, b: float): def pointwise_become_partial(self, vmobject: VMobject, a: float, b: float) -> Self:
assert(isinstance(vmobject, VMobject)) assert(isinstance(vmobject, VMobject))
vm_points = vmobject.get_points() vm_points = vmobject.get_points()
self.data["joint_product"] = vmobject.data["joint_product"] self.data["joint_product"] = vmobject.data["joint_product"]
@ -1040,12 +1022,12 @@ class VMobject(Mobject):
self.set_points(new_points, refresh_joints=False) self.set_points(new_points, refresh_joints=False)
return self return self
def get_subcurve(self, a: float, b: float) -> VMobject: def get_subcurve(self, a: float, b: float) -> Self:
vmob = self.copy() vmob = self.copy()
vmob.pointwise_become_partial(self, a, b) vmob.pointwise_become_partial(self, a, b)
return vmob return vmob
def get_outer_vert_indices(self): def get_outer_vert_indices(self) -> np.ndarray:
""" """
Returns the pattern (0, 1, 2, 2, 3, 4, 4, 5, 6, ...) Returns the pattern (0, 1, 2, 2, 3, 4, 4, 5, 6, ...)
""" """
@ -1056,12 +1038,12 @@ class VMobject(Mobject):
# Data for shaders that may need refreshing # Data for shaders that may need refreshing
def refresh_triangulation(self): def refresh_triangulation(self) -> Self:
for mob in self.get_family(): for mob in self.get_family():
mob.needs_new_triangulation = True mob.needs_new_triangulation = True
return self return self
def get_triangulation(self): def get_triangulation(self) -> np.ndarray:
# Figure out how to triangulate the interior to know # Figure out how to triangulate the interior to know
# how to send the points as to the vertex shader. # how to send the points as to the vertex shader.
# First triangles come directly from the points # First triangles come directly from the points
@ -1118,12 +1100,12 @@ class VMobject(Mobject):
self.needs_new_triangulation = False self.needs_new_triangulation = False
return tri_indices return tri_indices
def refresh_joint_products(self): def refresh_joint_products(self) -> Self:
for mob in self.get_family(): for mob in self.get_family():
mob.needs_new_joint_products = True mob.needs_new_joint_products = True
return self return self
def get_joint_products(self, refresh: bool = False): def get_joint_products(self, refresh: bool = False) -> np.ndarray:
""" """
The 'joint product' is a 4-vector holding the cross and dot The 'joint product' is a 4-vector holding the cross and dot
product between tangent vectors at a joint product between tangent vectors at a joint
@ -1174,10 +1156,11 @@ class VMobject(Mobject):
self.data["joint_product"][:, 3] = (vect_to_vert * vect_from_vert).sum(1) self.data["joint_product"][:, 3] = (vect_to_vert * vect_from_vert).sum(1)
return self.data["joint_product"] return self.data["joint_product"]
def lock_matching_data(self, vmobject1: VMobject, vmobject2: VMobject): def lock_matching_data(self, vmobject1: VMobject, vmobject2: VMobject) -> Self:
for mob in [self, vmobject1, vmobject2]: for mob in [self, vmobject1, vmobject2]:
mob.get_joint_products() mob.get_joint_products()
super().lock_matching_data(vmobject1, vmobject2) super().lock_matching_data(vmobject1, vmobject2)
return self
def triggers_refreshed_triangulation(func: Callable): def triggers_refreshed_triangulation(func: Callable):
@wraps(func) @wraps(func)
@ -1189,7 +1172,7 @@ class VMobject(Mobject):
return self return self
return wrapper return wrapper
def set_points(self, points: Vect3Array, refresh_joints: bool = True): def set_points(self, points: Vect3Array, refresh_joints: bool = True) -> Self:
assert(len(points) == 0 or len(points) % 2 == 1) assert(len(points) == 0 or len(points) % 2 == 1)
super().set_points(points) super().set_points(points)
self.refresh_triangulation() self.refresh_triangulation()
@ -1199,13 +1182,13 @@ class VMobject(Mobject):
return self return self
@triggers_refreshed_triangulation @triggers_refreshed_triangulation
def append_points(self, points: Vect3Array): def append_points(self, points: Vect3Array) -> Self:
assert(len(points) % 2 == 0) assert(len(points) % 2 == 0)
super().append_points(points) super().append_points(points)
return self return self
@triggers_refreshed_triangulation @triggers_refreshed_triangulation
def reverse_points(self, recurse: bool = True): def reverse_points(self, recurse: bool = True) -> Self:
# This will reset which anchors are # This will reset which anchors are
# considered path ends # considered path ends
for mob in self.get_family(recurse): for mob in self.get_family(recurse):
@ -1218,7 +1201,7 @@ class VMobject(Mobject):
return self return self
@triggers_refreshed_triangulation @triggers_refreshed_triangulation
def set_data(self, data: np.ndarray): def set_data(self, data: np.ndarray) -> Self:
super().set_data(data) super().set_data(data)
return self return self
@ -1229,15 +1212,16 @@ class VMobject(Mobject):
function: Callable[[Vect3], Vect3], function: Callable[[Vect3], Vect3],
make_smooth: bool = False, make_smooth: bool = False,
**kwargs **kwargs
): ) -> Self:
super().apply_function(function, **kwargs) super().apply_function(function, **kwargs)
if self.make_smooth_after_applying_functions or make_smooth: if self.make_smooth_after_applying_functions or make_smooth:
self.make_smooth(approx=True) self.make_smooth(approx=True)
return self return self
def apply_points_function(self, *args, **kwargs): def apply_points_function(self, *args, **kwargs) -> Self:
super().apply_points_function(*args, **kwargs) super().apply_points_function(*args, **kwargs)
self.refresh_joint_products() self.refresh_joint_products()
return self
def set_animating_status(self, is_animating: bool, recurse: bool = True): def set_animating_status(self, is_animating: bool, recurse: bool = True):
super().set_animating_status(is_animating, recurse) super().set_animating_status(is_animating, recurse)
@ -1281,7 +1265,7 @@ class VMobject(Mobject):
self.stroke_shader_wrapper, self.stroke_shader_wrapper,
] ]
def refresh_shader_wrapper_id(self): def refresh_shader_wrapper_id(self) -> Self:
if not self._shaders_initialized: if not self._shaders_initialized:
return self return self
for wrapper in self.shader_wrappers: for wrapper in self.shader_wrappers:
@ -1348,7 +1332,7 @@ class VGroup(VMobject):
super().__init__(**kwargs) super().__init__(**kwargs)
self.add(*vmobjects) self.add(*vmobjects)
def __add__(self, other: VMobject | VGroup): def __add__(self, other: VMobject) -> Self:
assert(isinstance(other, VMobject)) assert(isinstance(other, VMobject))
return self.add(other) return self.add(other)

View file

@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
import numpy as np import numpy as np
from typing import Self
from manimlib.mobject.mobject import Mobject from manimlib.mobject.mobject import Mobject
from manimlib.utils.iterables import listify from manimlib.utils.iterables import listify
@ -36,7 +36,7 @@ class ValueTracker(Mobject):
return result[0] return result[0]
return result return result
def set_value(self, value: float | complex | np.ndarray): def set_value(self, value: float | complex | np.ndarray) -> Self:
self.uniforms["value"][:] = value self.uniforms["value"][:] = value
return self return self

View file

@ -22,7 +22,6 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Callable, Iterable, Sequence, TypeVar, Tuple from typing import Callable, Iterable, Sequence, TypeVar, Tuple
import numpy.typing as npt
from manimlib.typing import ManimColor, Vect3, VectN, Vect3Array from manimlib.typing import ManimColor, Vect3, VectN, Vect3Array
from manimlib.mobject.coordinate_systems import CoordinateSystem from manimlib.mobject.coordinate_systems import CoordinateSystem