Merge pull request #1981 from 3b1b/add-self-type

Add Self type
This commit is contained in:
Grant Sanderson 2023-02-01 11:27:00 -08:00 committed by GitHub
commit 5527c0706d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 334 additions and 333 deletions

View file

@ -11,7 +11,7 @@ from manimlib.utils.rate_functions import smooth
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Callable, List, Iterable
from typing import Callable, List, Iterable, Self
from manimlib.typing import ManimColor, Vect3
@ -49,7 +49,7 @@ class AnimatedBoundary(VGroup):
lambda m, dt: self.update_boundary_copies(dt)
)
def update_boundary_copies(self, dt: float) -> None:
def update_boundary_copies(self, dt: float) -> Self:
# Not actual time, but something which passes at
# an altered rate to make the implementation below
# cleaner
@ -79,6 +79,7 @@ class AnimatedBoundary(VGroup):
)
self.total_time += dt
return self
def full_family_become_partial(
self,
@ -86,7 +87,7 @@ class AnimatedBoundary(VGroup):
mob2: VMobject,
a: float,
b: float
):
) -> Self:
family1 = mob1.family_members_with_points()
family2 = mob2.family_members_with_points()
for sm1, sm2 in zip(family1, family2):
@ -118,7 +119,7 @@ class TracedPath(VMobject):
self.traced_points: list[np.ndarray] = []
self.add_updater(lambda m, dt: m.update_path(dt))
def update_path(self, dt: float):
def update_path(self, dt: float) -> Self:
if dt == 0:
return self
point = self.traced_point_func().copy()

View file

@ -21,6 +21,7 @@ from manimlib.mobject.svg.tex_mobject import Tex
from manimlib.mobject.types.dot_cloud import DotCloud
from manimlib.mobject.types.surface import ParametricSurface
from manimlib.mobject.types.vectorized_mobject import VGroup
from manimlib.mobject.types.vectorized_mobject import VMobject
from manimlib.utils.dict_ops import merge_dicts_recursively
from manimlib.utils.simple_functions import binary_search
from manimlib.utils.space_ops import angle_of_vector
@ -31,7 +32,7 @@ from manimlib.utils.space_ops import normalize
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Callable, Iterable, Sequence, Type, TypeVar
from typing import Callable, Iterable, Sequence, Type, TypeVar, Optional, Self
from manimlib.mobject.mobject import Mobject
from manimlib.typing import ManimColor, Vect3, Vect3Array, VectN, RangeSpecifier
@ -235,7 +236,13 @@ class CoordinateSystem(ABC):
"""
return self.input_to_graph_point(x, graph)
def bind_graph_to_func(self, graph, func, jagged=False, get_discontinuities=None):
def bind_graph_to_func(
self,
graph: VMobject,
func: Callable[[Vect3], Vect3],
jagged: bool = False,
get_discontinuities: Optional[Callable[[], Vect3]] = None
) -> VMobject:
"""
Use for graphing functions which might change over time, or change with
conditions
@ -659,7 +666,7 @@ class NumberPlane(Axes):
kwargs["buff"] = 0
return Arrow(self.c2p(0, 0), self.c2p(*coords), **kwargs)
def prepare_for_nonlinear_transform(self, num_inserted_curves: int = 50):
def prepare_for_nonlinear_transform(self, num_inserted_curves: int = 50) -> Self:
for mob in self.family_members_with_points():
num_curves = mob.get_num_curves()
if num_inserted_curves > num_curves:
@ -698,7 +705,7 @@ class ComplexPlane(NumberPlane):
skip_first: bool = True,
font_size: int = 36,
**kwargs
):
) -> Self:
if numbers is None:
numbers = self.get_default_coordinate_values(skip_first)

View file

@ -30,7 +30,7 @@ from manimlib.utils.space_ops import rotation_matrix_transpose
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Iterable
from typing import Iterable, Self, Optional
from manimlib.typing import ManimColor, Vect3, Vect3Array
@ -67,7 +67,7 @@ class TipableVMobject(VMobject):
)
# Adding, Creating, Modifying tips
def add_tip(self, at_start: bool = False, **kwargs):
def add_tip(self, at_start: bool = False, **kwargs) -> Self:
"""
Adds a tip to the TipableVMobject instance, recognising
that the endpoints might need to be switched if it's
@ -112,7 +112,7 @@ class TipableVMobject(VMobject):
tip.shift(anchor - tip.get_tip_point())
return tip
def reset_endpoints_based_on_tip(self, tip: ArrowTip, at_start: bool):
def reset_endpoints_based_on_tip(self, tip: ArrowTip, at_start: bool) -> Self:
if self.get_length() == 0:
# Zero length, put_start_and_end_on wouldn't
# work
@ -127,7 +127,7 @@ class TipableVMobject(VMobject):
self.put_start_and_end_on(start, end)
return self
def asign_tip_attr(self, tip: ArrowTip, at_start: bool):
def asign_tip_attr(self, tip: ArrowTip, at_start: bool) -> Self:
if at_start:
self.start_tip = tip
else:
@ -258,7 +258,7 @@ class Arc(TipableVMobject):
angle = angle_of_vector(self.get_end() - self.get_arc_center())
return angle % TAU
def move_arc_center_to(self, point: Vect3):
def move_arc_center_to(self, point: Vect3) -> Self:
self.shift(point - self.get_arc_center())
return self
@ -318,7 +318,7 @@ class Circle(Arc):
dim_to_match: int = 0,
stretch: bool = False,
buff: float = MED_SMALL_BUFF
):
) -> Self:
self.replace(mobject, dim_to_match, stretch)
self.stretch((self.get_width() + 2 * buff) / self.get_width(), 0)
self.stretch((self.get_height() + 2 * buff) / self.get_height(), 1)
@ -475,7 +475,7 @@ class Line(TipableVMobject):
end: Vect3,
buff: float = 0,
path_arc: float = 0
):
) -> Self:
vect = end - start
dist = get_norm(vect)
if np.isclose(dist, 0):
@ -504,9 +504,10 @@ class Line(TipableVMobject):
self.set_points_as_corners([start, end])
return self
def set_path_arc(self, new_value: float) -> None:
def set_path_arc(self, new_value: float) -> Self:
self.path_arc = new_value
self.init_points()
return self
def set_start_and_end_attrs(self, start: Vect3 | Mobject, end: Vect3 | Mobject):
# If either start or end are Mobjects, this
@ -541,7 +542,7 @@ class Line(TipableVMobject):
result[:len(point)] = point
return result
def put_start_and_end_on(self, start: Vect3, end: Vect3):
def put_start_and_end_on(self, start: Vect3, end: Vect3) -> Self:
curr_start, curr_end = self.get_start_and_end()
if np.isclose(curr_start, curr_end).all():
# Handle null lines more gracefully
@ -569,7 +570,7 @@ class Line(TipableVMobject):
def get_slope(self) -> float:
return np.tan(self.get_angle())
def set_angle(self, angle: float, about_point: Vect3 | None = None):
def set_angle(self, angle: float, about_point: Optional[Vect3] = None) -> Self:
if about_point is None:
about_point = self.get_start()
self.rotate(
@ -695,13 +696,13 @@ class Arrow(Line):
end: Vect3,
buff: float = 0,
path_arc: float = 0
):
) -> Self:
super().set_points_by_ends(start, end, buff, path_arc)
self.insert_tip_anchor()
self.create_tip_with_stroke_width()
return self
def insert_tip_anchor(self):
def insert_tip_anchor(self) -> Self:
prev_end = self.get_end()
arc_len = self.get_arc_length()
tip_len = self.get_stroke_width() * self.width_to_tip_len * self.tip_width_ratio
@ -716,7 +717,7 @@ class Arrow(Line):
return self
@Mobject.affects_data
def create_tip_with_stroke_width(self):
def create_tip_with_stroke_width(self) -> Self:
if self.get_num_points() < 3:
return self
tip_width = self.tip_width_ratio * min(
@ -727,7 +728,7 @@ class Arrow(Line):
self.data['stroke_width'][-3:, 0] = tip_width * np.linspace(1, 0, 3)
return self
def reset_tip(self):
def reset_tip(self) -> Self:
self.set_points_by_ends(
self.get_start(), self.get_end(),
path_arc=self.path_arc
@ -739,13 +740,13 @@ class Arrow(Line):
color: ManimColor | Iterable[ManimColor] | None = None,
width: float | Iterable[float] | None = None,
*args, **kwargs
):
) -> Self:
super().set_stroke(color=color, width=width, *args, **kwargs)
if self.has_points():
self.reset_tip()
return self
def _handle_scale_side_effects(self, scale_factor: float):
def _handle_scale_side_effects(self, scale_factor: float) -> Self:
if scale_factor != 1.0:
self.reset_tip()
return self
@ -787,7 +788,7 @@ class FillArrow(Line):
end: Vect3,
buff: float = 0,
path_arc: float = 0
) -> None:
) -> Self:
# Find the right tip length and thickness
vect = end - start
length = max(get_norm(vect), 1e-8)
@ -848,8 +849,9 @@ class FillArrow(Line):
axis=rotate_vector(self.get_unit_vector(), -PI / 2),
)
self.shift(start - self.get_start())
return self
def reset_points_around_ends(self):
def reset_points_around_ends(self) -> Self:
self.set_points_by_ends(
self.get_start().copy(),
self.get_end().copy(),
@ -864,21 +866,21 @@ class FillArrow(Line):
def get_end(self) -> Vect3:
return self.get_points()[self.tip_index]
def put_start_and_end_on(self, start: Vect3, end: Vect3):
def put_start_and_end_on(self, start: Vect3, end: Vect3) -> Self:
self.set_points_by_ends(start, end, buff=0, path_arc=self.path_arc)
return self
def scale(self, *args, **kwargs):
def scale(self, *args, **kwargs) -> Self:
super().scale(*args, **kwargs)
self.reset_points_around_ends()
return self
def set_thickness(self, thickness: float):
def set_thickness(self, thickness: float) -> Self:
self.thickness = thickness
self.reset_points_around_ends()
return self
def set_path_arc(self, path_arc: float):
def set_path_arc(self, path_arc: float) -> Self:
self.path_arc = path_arc
self.reset_points_around_ends()
return self
@ -921,7 +923,7 @@ class Polygon(VMobject):
def get_vertices(self) -> Vect3Array:
return self.get_start_anchors()
def round_corners(self, radius: float | None = None):
def round_corners(self, radius: Optional[float] = None) -> Self:
if radius is None:
verts = self.get_vertices()
min_edge_length = min(

View file

@ -18,7 +18,7 @@ from manimlib.mobject.types.vectorized_mobject import VMobject
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Sequence
from typing import Sequence, Self
import numpy.typing as npt
from manimlib.mobject.mobject import Mobject
from manimlib.typing import ManimColor, Vect3
@ -129,7 +129,7 @@ class Matrix(VMobject):
v_buff: float,
h_buff: float,
aligned_corner: Vect3,
):
) -> Self:
for i, row in enumerate(matrix):
for j, elem in enumerate(row):
mob = matrix[i][j]
@ -139,7 +139,7 @@ class Matrix(VMobject):
)
return self
def add_brackets(self, v_buff: float, h_buff: float):
def add_brackets(self, v_buff: float, h_buff: float) -> Self:
height = len(self.mob_matrix)
brackets = Tex("".join((
R"\left[\begin{array}{c}",
@ -168,13 +168,13 @@ class Matrix(VMobject):
for row in self.mob_matrix
])
def set_column_colors(self, *colors: ManimColor):
def set_column_colors(self, *colors: ManimColor) -> Self:
columns = self.get_columns()
for color, column in zip(colors, columns):
column.set_color(color)
return self
def add_background_to_entries(self):
def add_background_to_entries(self) -> Self:
for mob in self.get_entries():
mob.add_background_rectangle()
return self

File diff suppressed because it is too large Load diff

View file

@ -11,7 +11,7 @@ from manimlib.mobject.types.vectorized_mobject import VMobject
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import TypeVar
from typing import TypeVar, Self
from manimlib.typing import ManimColor, Vect3
T = TypeVar("T", bound=VMobject)
@ -163,7 +163,7 @@ class DecimalNumber(VMobject):
def get_tex(self):
return self.num_string
def set_value(self, number: float | complex):
def set_value(self, number: float | complex) -> Self:
move_to_point = self.get_edge_center(self.edge_to_fix)
style = self.family_members_with_points()[0].get_style()
self.set_submobjects_from_number(number)
@ -171,14 +171,16 @@ class DecimalNumber(VMobject):
self.set_style(**style)
return self
def _handle_scale_side_effects(self, scale_factor: float) -> None:
def _handle_scale_side_effects(self, scale_factor: float) -> Self:
self.uniforms["font_size"] = scale_factor * self.uniforms["font_size"]
return self
def get_value(self) -> float | complex:
return self.number
def increment_value(self, delta_t: float | complex = 1) -> None:
def increment_value(self, delta_t: float | complex = 1) -> Self:
self.set_value(self.get_value() + delta_t)
return self
class Integer(DecimalNumber):

View file

@ -14,7 +14,7 @@ from manimlib.utils.customization import get_customization
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Sequence
from typing import Sequence, Self
from manimlib.mobject.mobject import Mobject
from manimlib.typing import ManimColor
@ -60,20 +60,20 @@ class BackgroundRectangle(SurroundingRectangle):
)
self.original_fill_opacity = fill_opacity
def pointwise_become_partial(self, mobject: Mobject, a: float, b: float):
def pointwise_become_partial(self, mobject: Mobject, a: float, b: float) -> Self:
self.set_fill(opacity=b * self.original_fill_opacity)
return self
def set_style_data(
def set_style(
self,
stroke_color: ManimColor | None = None,
stroke_width: float | None = None,
fill_color: ManimColor | None = None,
fill_opacity: float | None = None,
family: bool = True
):
) -> Self:
# Unchangeable style, except for fill_opacity
VMobject.set_style_data(
VMobject.set_style(
self,
stroke_color=BLACK,
stroke_width=0,

View file

@ -166,7 +166,6 @@ class Cylinder(Surface):
self.scale(self.radius)
self.set_depth(self.height, stretch=True)
self.apply_matrix(z_to_vector(self.axis))
return self
def uv_func(self, u: float, v: float) -> np.ndarray:
return np.array([np.cos(u), np.sin(u), v])
@ -186,6 +185,7 @@ class Line3D(Cylinder):
height=get_norm(axis),
radius=width / 2,
axis=axis,
resolution=resolution,
**kwargs
)
self.shift((start + end) / 2)
@ -376,16 +376,6 @@ class Dodecahedron(VGroup3D):
super().__init__(*pentagons, **style)
# # Rotate those two pentagons by all the axis permuations to fill
# # out the dodecahedron
# Id = np.identity(3)
# for i in range(3):
# perm = [j % 3 for j in range(i, i + 3)]
# for b in [1, -1]:
# matrix = b * np.array([Id[0][perm], Id[1][perm], Id[2][perm]])
# self.add(pentagon1.copy().apply_matrix(matrix, about_point=ORIGIN))
# self.add(pentagon2.copy().apply_matrix(matrix, about_point=ORIGIN))
class Prismify(VGroup3D):
def __init__(self, vmobject, depth=1.0, direction=IN, **kwargs):

View file

@ -13,7 +13,7 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING:
import numpy.typing as npt
from typing import Sequence, Tuple
from typing import Sequence, Tuple, Self
from manimlib.typing import ManimColor, Vect3, Vect3Array
@ -70,7 +70,7 @@ class DotCloud(PMobject):
v_buff_ratio: float = 1.0,
d_buff_ratio: float = 1.0,
height: float = DEFAULT_GRID_HEIGHT,
):
) -> Self:
n_points = n_rows * n_cols * n_layers
points = np.repeat(range(n_points), 3, axis=0).reshape((n_points, 3))
points[:, 0] = points[:, 0] % n_cols
@ -96,7 +96,7 @@ class DotCloud(PMobject):
return self
@Mobject.affects_data
def set_radii(self, radii: npt.ArrayLike):
def set_radii(self, radii: npt.ArrayLike) -> Self:
n_points = self.get_num_points()
radii = np.array(radii).reshape((len(radii), 1))
self.data["radius"][:] = resize_with_interpolation(radii, n_points)
@ -107,7 +107,7 @@ class DotCloud(PMobject):
return self.data["radius"]
@Mobject.affects_data
def set_radius(self, radius: float):
def set_radius(self, radius: float) -> Self:
data = self.data if self.get_num_points() > 0 else self._data_defaults
data["radius"][:] = radius
self.refresh_bounding_box()
@ -116,13 +116,14 @@ class DotCloud(PMobject):
def get_radius(self) -> float:
return self.get_radii().max()
def set_glow_factor(self, glow_factor: float) -> None:
def set_glow_factor(self, glow_factor: float) -> Self:
self.uniforms["glow_factor"] = glow_factor
return self
def get_glow_factor(self) -> float:
return self.uniforms["glow_factor"]
def compute_bounding_box(self) -> np.ndarray:
def compute_bounding_box(self) -> Vect3Array:
bb = super().compute_bounding_box()
radius = self.get_radius()
bb[0] += np.full((3,), -radius)
@ -134,7 +135,7 @@ class DotCloud(PMobject):
scale_factor: float | npt.ArrayLike,
scale_radii: bool = True,
**kwargs
):
) -> Self:
super().scale(scale_factor, **kwargs)
if scale_radii:
self.set_radii(scale_factor * self.get_radii())
@ -145,7 +146,7 @@ class DotCloud(PMobject):
reflectiveness: float = 0.5,
gloss: float = 0.1,
shadow: float = 0.2
):
) -> Self:
self.set_shading(reflectiveness, gloss, shadow)
self.apply_depth_test()
return self

View file

@ -10,7 +10,7 @@ from manimlib.utils.iterables import resize_with_interpolation
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Callable
from typing import Callable, Self
from manimlib.typing import ManimColor, Vect3, Vect3Array, Vect4Array
@ -28,7 +28,7 @@ class PMobject(Mobject):
rgbas: Vect4Array | None = None,
color: ManimColor | None = None,
opacity: float | None = None
):
) -> Self:
"""
points must be a Nx3 numpy array, as must rgbas if it is not None
"""
@ -46,13 +46,13 @@ class PMobject(Mobject):
self.data["rgba"][-len(rgbas):] = rgbas
return self
def add_point(self, point: Vect3, rgba=None, color=None, opacity=None):
def add_point(self, point: Vect3, rgba=None, color=None, opacity=None) -> Self:
rgbas = None if rgba is None else [rgba]
self.add_points([point], rgbas, color, opacity)
return self
@Mobject.affects_data
def set_color_by_gradient(self, *colors: ManimColor):
def set_color_by_gradient(self, *colors: ManimColor) -> Self:
self.data["rgba"][:] = np.array(list(map(
color_to_rgba,
color_gradient(colors, self.get_num_points())
@ -60,20 +60,20 @@ class PMobject(Mobject):
return self
@Mobject.affects_data
def match_colors(self, pmobject: PMobject):
def match_colors(self, pmobject: PMobject) -> Self:
self.data["rgba"][:] = resize_with_interpolation(
pmobject.data["rgba"], self.get_num_points()
)
return self
@Mobject.affects_data
def filter_out(self, condition: Callable[[np.ndarray], bool]):
def filter_out(self, condition: Callable[[np.ndarray], bool]) -> Self:
for mob in self.family_members_with_points():
mob.data = mob.data[~np.apply_along_axis(condition, 1, mob.get_points())]
return self
@Mobject.affects_data
def sort_points(self, function: Callable[[Vect3], None] = lambda p: p[0]):
def sort_points(self, function: Callable[[Vect3], None] = lambda p: p[0]) -> Self:
"""
function is any map from R^3 to R
"""
@ -85,7 +85,7 @@ class PMobject(Mobject):
return self
@Mobject.affects_data
def ingest_submobjects(self):
def ingest_submobjects(self) -> Self:
self.data = np.vstack([
sm.data for sm in self.get_family()
])
@ -96,7 +96,7 @@ class PMobject(Mobject):
return self.get_points()[int(index)]
@Mobject.affects_data
def pointwise_become_partial(self, pmobject: PMobject, a: float, b: float):
def pointwise_become_partial(self, pmobject: PMobject, a: float, b: float) -> Self:
lower_index = int(a * pmobject.get_num_points())
upper_index = int(b * pmobject.get_num_points())
self.data = pmobject.data[lower_index:upper_index].copy()

View file

@ -17,7 +17,7 @@ from manimlib.utils.space_ops import cross
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Callable, Iterable, Sequence, Tuple
from typing import Callable, Iterable, Sequence, Tuple, Self
from manimlib.camera.camera import Camera
from manimlib.typing import ManimColor, Vect3, Vect3Array
@ -100,18 +100,19 @@ class Surface(Mobject):
(dv_points - points) / self.epsilon,
), 1)
def apply_points_function(self, *args, **kwargs):
def apply_points_function(self, *args, **kwargs) -> Self:
super().apply_points_function(*args, **kwargs)
self.get_unit_normals()
return self
def compute_triangle_indices(self):
def compute_triangle_indices(self) -> np.ndarray:
# TODO, if there is an event which changes
# the resolution of the surface, make sure
# this is called.
nu, nv = self.resolution
if nu == 0 or nv == 0:
self.triangle_indices = np.zeros(0, dtype=int)
return
return self.triangle_indices
index_grid = np.arange(nu * nv).reshape((nu, nv))
indices = np.zeros(6 * (nu - 1) * (nv - 1), dtype=int)
indices[0::6] = index_grid[:-1, :-1].flatten() # Top left
@ -121,6 +122,7 @@ class Surface(Mobject):
indices[4::6] = index_grid[+1:, :-1].flatten() # Bottom left
indices[5::6] = index_grid[+1:, +1:].flatten() # Bottom right
self.triangle_indices = indices
return self.triangle_indices
def get_triangle_indices(self) -> np.ndarray:
return self.triangle_indices
@ -154,7 +156,7 @@ class Surface(Mobject):
a: float,
b: float,
axis: int | None = None
):
) -> Self:
assert(isinstance(smobject, Surface))
if axis is None:
axis = self.prefered_creation_axis
@ -211,7 +213,7 @@ class Surface(Mobject):
return points.reshape((nu * nv, *resolution[2:]))
@Mobject.affects_data
def sort_faces_back_to_front(self, vect: Vect3 = OUT):
def sort_faces_back_to_front(self, vect: Vect3 = OUT) -> Self:
tri_is = self.triangle_indices
points = self.get_points()
@ -221,24 +223,25 @@ class Surface(Mobject):
tri_is[k::3] = tri_is[k::3][indices]
return self
def always_sort_to_camera(self, camera: Camera):
def always_sort_to_camera(self, camera: Camera) -> Self:
def updater(surface: Surface):
vect = camera.get_location() - surface.get_center()
surface.sort_faces_back_to_front(vect)
self.add_updater(updater)
return self
def set_clip_plane(
self,
vect: Vect3 | None = None,
threshold: float | None = None
):
) -> Self:
if vect is not None:
self.uniforms["clip_plane"][:3] = vect
if threshold is not None:
self.uniforms["clip_plane"][3] = threshold
return self
def deactivate_clip_plane(self):
def deactivate_clip_plane(self) -> Self:
self.uniforms["clip_plane"][:] = 0
return self
@ -335,7 +338,7 @@ class TexturedSurface(Surface):
self.uniforms["num_textures"] = self.num_textures
@Mobject.affects_data
def set_opacity(self, opacity: float | Iterable[float]):
def set_opacity(self, opacity: float | Iterable[float]) -> Self:
op_arr = np.array(listify(opacity))
self.data["opacity"][:, 0] = resize_with_interpolation(op_arr, len(self.data))
return self
@ -345,7 +348,7 @@ class TexturedSurface(Surface):
color: ManimColor | Iterable[ManimColor] | None,
opacity: float | Iterable[float] | None = None,
recurse: bool = True
):
) -> Self:
if opacity is not None:
self.set_opacity(opacity)
return self
@ -356,7 +359,7 @@ class TexturedSurface(Surface):
a: float,
b: float,
axis: int = 1
):
) -> Self:
super().pointwise_become_partial(tsmobject, a, b, axis)
im_coords = self.data["im_coords"]
im_coords[:] = tsmobject.data["im_coords"]

View file

@ -45,7 +45,7 @@ from manimlib.shader_wrapper import FillShaderWrapper
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Callable, Iterable, Tuple
from typing import Callable, Iterable, Tuple, Any, Self
from manimlib.typing import ManimColor, Vect3, Vect4, Vect3Array, Vect4Array
from moderngl.context import Context
@ -128,29 +128,10 @@ class VMobject(Mobject):
self.uniforms["joint_type"] = JOINT_TYPE_MAP[self.joint_type]
self.uniforms["flat_stroke"] = float(self.flat_stroke)
# These are here just to make type checkers happy
def get_family(self, recurse: bool = True) -> list[VMobject]:
return super().get_family(recurse)
def family_members_with_points(self) -> list[VMobject]:
return super().family_members_with_points()
def replicate(self, n: int) -> VGroup:
return super().replicate(n)
def get_grid(self, *args, **kwargs) -> VGroup:
return super().get_grid(*args, **kwargs)
def __getitem__(self, value: int | slice) -> VMobject:
return super().__getitem__(value)
def __iter__(self) -> Iterable[VMobject]:
return super().__iter__()
def add(self, *vmobjects: VMobject):
def add(self, *vmobjects: VMobject) -> Self:
if not all((isinstance(m, VMobject) for m in vmobjects)):
raise Exception("All submobjects must be of type VMobject")
super().add(*vmobjects)
return super().add(*vmobjects)
# Colors
def init_colors(self):
@ -175,7 +156,7 @@ class VMobject(Mobject):
rgba_array: Vect4Array,
name: str | None = None,
recurse: bool = False
):
) -> Self:
if name is None:
names = ["fill_rgba", "stroke_rgba"]
else:
@ -191,7 +172,7 @@ class VMobject(Mobject):
opacity: float | Iterable[float] | None = None,
border_width: float | None = None,
recurse: bool = True
):
) -> Self:
self.set_rgba_array_by_color(color, opacity, 'fill_rgba', recurse)
if border_width is not None:
for mob in self.get_family(recurse):
@ -205,7 +186,7 @@ class VMobject(Mobject):
opacity: float | Iterable[float] | None = None,
background: bool | None = None,
recurse: bool = True
):
) -> Self:
self.set_rgba_array_by_color(color, opacity, 'stroke_rgba', recurse)
if width is not None:
@ -228,7 +209,7 @@ class VMobject(Mobject):
color: ManimColor | Iterable[ManimColor] = BLACK,
width: float | Iterable[float] = 3,
background: bool = True
):
) -> Self:
self.set_stroke(color, width, background=background)
return self
@ -245,7 +226,7 @@ class VMobject(Mobject):
stroke_background: bool = True,
shading: Tuple[float, float, float] | None = None,
recurse: bool = True
):
) -> Self:
for mob in self.get_family(recurse):
if fill_rgba is not None:
mob.data['fill_rgba'][:] = resize_with_interpolation(fill_rgba, len(mob.data['fill_rgba']))
@ -276,7 +257,7 @@ class VMobject(Mobject):
mob.set_shading(*shading, recurse=False)
return self
def get_style(self):
def get_style(self) -> dict[str, Any]:
data = self.data if self.get_num_points() > 0 else self._data_defaults
return {
"fill_rgba": data['fill_rgba'].copy(),
@ -286,7 +267,7 @@ class VMobject(Mobject):
"shading": self.get_shading(),
}
def match_style(self, vmobject: VMobject, recurse: bool = True):
def match_style(self, vmobject: VMobject, recurse: bool = True) -> Self:
self.set_style(**vmobject.get_style(), recurse=False)
if recurse:
# Does its best to match up submobject lists, and
@ -305,7 +286,7 @@ class VMobject(Mobject):
color: ManimColor | Iterable[ManimColor] | None,
opacity: float | Iterable[float] | None = None,
recurse: bool = True
):
) -> Self:
self.set_fill(color, opacity=opacity, recurse=recurse)
self.set_stroke(color, opacity=opacity, recurse=recurse)
return self
@ -314,16 +295,16 @@ class VMobject(Mobject):
self,
opacity: float | Iterable[float] | None,
recurse: bool = True
):
) -> Self:
self.set_fill(opacity=opacity, recurse=recurse)
self.set_stroke(opacity=opacity, recurse=recurse)
return self
def set_anti_alias_width(self, anti_alias_width: float, recurse: bool = True):
def set_anti_alias_width(self, anti_alias_width: float, recurse: bool = True) -> Self:
self.set_uniform(recurse, anti_alias_width=anti_alias_width)
return self
def fade(self, darkness: float = 0.5, recurse: bool = True):
def fade(self, darkness: float = 0.5, recurse: bool = True) -> Self:
mobs = self.get_family() if recurse else [self]
for mob in mobs:
factor = 1.0 - darkness
@ -407,7 +388,7 @@ class VMobject(Mobject):
return self.get_fill_opacity()
return self.get_stroke_opacity()
def set_flat_stroke(self, flat_stroke: bool = True, recurse: bool = True):
def set_flat_stroke(self, flat_stroke: bool = True, recurse: bool = True) -> Self:
for mob in self.get_family(recurse):
mob.uniforms["flat_stroke"] = float(flat_stroke)
return self
@ -415,7 +396,7 @@ class VMobject(Mobject):
def get_flat_stroke(self) -> bool:
return self.uniforms["flat_stroke"] == 1.0
def set_joint_type(self, joint_type: str, recurse: bool = True):
def set_joint_type(self, joint_type: str, recurse: bool = True) -> Self:
for mob in self.get_family(recurse):
mob.uniforms["joint_type"] = JOINT_TYPE_MAP[joint_type]
return self
@ -428,7 +409,7 @@ class VMobject(Mobject):
anti_alias_width: float = 0,
fill_border_width: float = 0,
recurse: bool=True
):
) -> Self:
super().apply_depth_test(recurse)
self.set_anti_alias_width(anti_alias_width)
self.set_fill(border_width=fill_border_width)
@ -439,14 +420,14 @@ class VMobject(Mobject):
anti_alias_width: float = 1.0,
fill_border_width: float = 0.5,
recurse: bool=True
):
) -> Self:
super().apply_depth_test(recurse)
self.set_anti_alias_width(anti_alias_width)
self.set_fill(border_width=fill_border_width)
return self
@Mobject.affects_family_data
def use_winding_fill(self, value: bool = True, recurse: bool = True):
def use_winding_fill(self, value: bool = True, recurse: bool = True) -> Self:
for submob in self.get_family(recurse):
submob._use_winding_fill = value
if not value and submob.has_points():
@ -458,7 +439,7 @@ class VMobject(Mobject):
self,
anchors: Vect3Array,
handles: Vect3Array,
):
) -> Self:
assert(len(anchors) == len(handles) + 1)
points = resize_array(self.get_points(), 2 * len(anchors) - 1)
points[0::2] = anchors
@ -466,7 +447,7 @@ class VMobject(Mobject):
self.set_points(points)
return self
def start_new_path(self, point: Vect3):
def start_new_path(self, point: Vect3) -> Self:
# Path ends are signaled by a handle point sitting directly
# on top of the previous anchor
if self.has_points():
@ -481,7 +462,7 @@ class VMobject(Mobject):
handle1: Vect3,
handle2: Vect3,
anchor2: Vect3
):
) -> Self:
self.start_new_path(anchor1)
self.add_cubic_bezier_curve_to(handle1, handle2, anchor2)
return self
@ -491,7 +472,7 @@ class VMobject(Mobject):
handle1: Vect3,
handle2: Vect3,
anchor: Vect3,
):
) -> Self:
"""
Add cubic bezier curve to the path.
"""
@ -513,7 +494,7 @@ class VMobject(Mobject):
self.append_points(quad_approx[1:])
return self
def add_quadratic_bezier_curve_to(self, handle: Vect3, anchor: Vect3):
def add_quadratic_bezier_curve_to(self, handle: Vect3, anchor: Vect3) -> Self:
self.throw_error_if_no_points()
last_point = self.get_last_point()
if self.consider_points_equal(handle, last_point):
@ -522,14 +503,14 @@ class VMobject(Mobject):
self.append_points([handle, anchor])
return self
def add_line_to(self, point: Vect3):
def add_line_to(self, point: Vect3) -> Self:
self.throw_error_if_no_points()
last_point = self.get_last_point()
alphas = np.linspace(0, 1, 5 if self.long_lines else 3)
self.append_points(outer_interpolate(last_point, point, alphas[1:]))
return self
def add_smooth_curve_to(self, point: Vect3):
def add_smooth_curve_to(self, point: Vect3) -> Self:
if self.has_new_path_started():
self.add_line_to(point)
else:
@ -538,7 +519,7 @@ class VMobject(Mobject):
self.add_quadratic_bezier_curve_to(new_handle, point)
return self
def add_smooth_cubic_curve_to(self, handle: Vect3, point: Vect3):
def add_smooth_cubic_curve_to(self, handle: Vect3, point: Vect3) -> Self:
self.throw_error_if_no_points()
if self.get_num_points() == 1:
new_handle = handle
@ -559,7 +540,7 @@ class VMobject(Mobject):
points = self.get_points()
return 2 * points[-1] - points[-2]
def close_path(self, smooth: bool = False):
def close_path(self, smooth: bool = False) -> Self:
if self.is_closed():
return self
last_path_start = self.get_subpaths()[-1][0]
@ -577,7 +558,7 @@ class VMobject(Mobject):
self,
tuple_to_subdivisions: Callable,
recurse: bool = True
):
) -> Self:
for vmob in self.get_family(recurse):
if not vmob.has_points():
continue
@ -599,7 +580,7 @@ class VMobject(Mobject):
self,
angle_threshold: float = 30 * DEGREES,
recurse: bool = True
):
) -> Self:
def tuple_to_subdivisions(b0, b1, b2):
angle = angle_between_vectors(b1 - b0, b2 - b1)
return int(angle / angle_threshold)
@ -607,7 +588,7 @@ class VMobject(Mobject):
self.subdivide_curves_by_condition(tuple_to_subdivisions, recurse)
return self
def subdivide_intersections(self, recurse: bool = True, n_subdivisions: int = 1):
def subdivide_intersections(self, recurse: bool = True, n_subdivisions: int = 1) -> Self:
path = self.get_anchors()
def tuple_to_subdivisions(b0, b1, b2):
if line_intersects_path(b0, b1, path):
@ -617,12 +598,12 @@ class VMobject(Mobject):
self.subdivide_curves_by_condition(tuple_to_subdivisions, recurse)
return self
def add_points_as_corners(self, points: Iterable[Vect3]):
def add_points_as_corners(self, points: Iterable[Vect3]) -> Self:
for point in points:
self.add_line_to(point)
return points
return self
def set_points_as_corners(self, points: Iterable[Vect3]):
def set_points_as_corners(self, points: Iterable[Vect3]) -> Self:
anchors = np.array(points)
handles = 0.5 * (anchors[:-1] + anchors[1:])
self.set_anchors_and_handles(anchors, handles)
@ -632,7 +613,7 @@ class VMobject(Mobject):
self,
points: Iterable[Vect3],
approx: bool = True
):
) -> Self:
self.set_points_as_corners(points)
self.make_smooth(approx=approx)
return self
@ -641,7 +622,7 @@ class VMobject(Mobject):
dots = self.get_joint_products()[::2, 3]
return bool((dots > 1 - 1e-3).all())
def change_anchor_mode(self, mode: str):
def change_anchor_mode(self, mode: str) -> Self:
assert(mode in ("jagged", "approx_smooth", "true_smooth"))
subpaths = self.get_subpaths()
self.clear_points()
@ -664,7 +645,7 @@ class VMobject(Mobject):
self.add_subpath(new_subpath)
return self
def make_smooth(self, approx=False, recurse=True):
def make_smooth(self, approx=False, recurse=True) -> Self:
"""
Edits the path so as to pass smoothly through all
the current anchor points.
@ -679,15 +660,16 @@ class VMobject(Mobject):
submob.change_anchor_mode(mode)
return self
def make_approximately_smooth(self, recurse=True):
def make_approximately_smooth(self, recurse=True) -> Self:
self.make_smooth(approx=True, recurse=recurse)
return self
def make_jagged(self, recurse=True):
def make_jagged(self, recurse=True) -> Self:
for submob in self.get_family(recurse):
submob.change_anchor_mode("jagged")
return self
def add_subpath(self, points: Vect3Array):
def add_subpath(self, points: Vect3Array) -> Self:
assert(len(points) % 2 == 1 or len(points) == 0)
if not self.has_points():
self.set_points(points)
@ -697,7 +679,7 @@ class VMobject(Mobject):
self.append_points(points[1:])
return self
def append_vectorized_mobject(self, vmobject: VMobject):
def append_vectorized_mobject(self, vmobject: VMobject) -> Self:
self.add_subpath(vmobject.get_points())
n = vmobject.get_num_points()
self.data[-n:] = vmobject.data
@ -715,7 +697,7 @@ class VMobject(Mobject):
def get_bezier_tuples(self) -> Iterable[Vect3Array]:
return self.get_bezier_tuples_from_points(self.get_points())
def get_subpath_end_indices_from_points(self, points: Vect3Array):
def get_subpath_end_indices_from_points(self, points: Vect3Array) -> np.ndarray:
atol = self.tolerance_for_point_equality
a0, h, a1 = points[0:-1:2], points[1::2], points[2::2]
# An anchor point is considered the end of a path
@ -731,7 +713,7 @@ class VMobject(Mobject):
is_end[:-1] = is_end[:-1] & ~is_end[1:]
return np.array([2 * n for n, end in enumerate(is_end) if end])
def get_subpath_end_indices(self):
def get_subpath_end_indices(self) -> np.ndarray:
return self.get_subpath_end_indices_from_points(self.get_points())
def get_subpaths_from_points(self, points: Vect3Array) -> list[Vect3Array]:
@ -860,7 +842,7 @@ class VMobject(Mobject):
self.data["unit_normal"][:] = normal
return normal
def refresh_unit_normal(self):
def refresh_unit_normal(self) -> Self:
self.get_unit_normal()
return self
@ -870,20 +852,20 @@ class VMobject(Mobject):
axis: Vect3 = OUT,
about_point: Vect3 | None = None,
**kwargs
):
) -> Self:
super().rotate(angle, axis, about_point, **kwargs)
for mob in self.get_family():
mob.refresh_unit_normal()
return self
def ensure_positive_orientation(self, recurse=True):
def ensure_positive_orientation(self, recurse=True) -> Self:
for mob in self.get_family(recurse):
if mob.get_unit_normal()[2] < 0:
mob.reverse_points()
return self
# Alignment
def align_points(self, vmobject: VMobject):
def align_points(self, vmobject: VMobject) -> Self:
winding = self._use_winding_fill and vmobject._use_winding_fill
self.use_winding_fill(winding)
vmobject.use_winding_fill(winding)
@ -940,7 +922,7 @@ class VMobject(Mobject):
mob.get_joint_products()
return self
def invisible_copy(self):
def invisible_copy(self) -> Self:
result = self.copy()
if not result.has_fill() or result.get_num_points() == 0:
return result
@ -948,14 +930,14 @@ class VMobject(Mobject):
result.set_opacity(0)
return result
def insert_n_curves(self, n: int, recurse: bool = True):
def insert_n_curves(self, n: int, recurse: bool = True) -> Self:
for mob in self.get_family(recurse):
if mob.get_num_curves() > 0:
new_points = mob.insert_n_curves_to_point_list(n, mob.get_points())
mob.set_points(new_points)
return self
def insert_n_curves_to_point_list(self, n: int, points: Vect3Array):
def insert_n_curves_to_point_list(self, n: int, points: Vect3Array) -> Vect3Array:
if len(points) == 1:
return np.repeat(points, 2 * n + 1, 0)
@ -988,7 +970,7 @@ class VMobject(Mobject):
mobject2: VMobject,
alpha: float,
*args, **kwargs
):
) -> Self:
super().interpolate(mobject1, mobject2, alpha, *args, **kwargs)
if self.has_fill() and not self._use_winding_fill:
tri1 = mobject1.get_triangulation()
@ -997,7 +979,7 @@ class VMobject(Mobject):
self.refresh_triangulation()
return self
def pointwise_become_partial(self, vmobject: VMobject, a: float, b: float):
def pointwise_become_partial(self, vmobject: VMobject, a: float, b: float) -> Self:
assert(isinstance(vmobject, VMobject))
vm_points = vmobject.get_points()
self.data["joint_product"] = vmobject.data["joint_product"]
@ -1040,12 +1022,12 @@ class VMobject(Mobject):
self.set_points(new_points, refresh_joints=False)
return self
def get_subcurve(self, a: float, b: float) -> VMobject:
def get_subcurve(self, a: float, b: float) -> Self:
vmob = self.copy()
vmob.pointwise_become_partial(self, a, b)
return vmob
def get_outer_vert_indices(self):
def get_outer_vert_indices(self) -> np.ndarray:
"""
Returns the pattern (0, 1, 2, 2, 3, 4, 4, 5, 6, ...)
"""
@ -1056,12 +1038,12 @@ class VMobject(Mobject):
# Data for shaders that may need refreshing
def refresh_triangulation(self):
def refresh_triangulation(self) -> Self:
for mob in self.get_family():
mob.needs_new_triangulation = True
return self
def get_triangulation(self):
def get_triangulation(self) -> np.ndarray:
# Figure out how to triangulate the interior to know
# how to send the points as to the vertex shader.
# First triangles come directly from the points
@ -1118,12 +1100,12 @@ class VMobject(Mobject):
self.needs_new_triangulation = False
return tri_indices
def refresh_joint_products(self):
def refresh_joint_products(self) -> Self:
for mob in self.get_family():
mob.needs_new_joint_products = True
return self
def get_joint_products(self, refresh: bool = False):
def get_joint_products(self, refresh: bool = False) -> np.ndarray:
"""
The 'joint product' is a 4-vector holding the cross and dot
product between tangent vectors at a joint
@ -1174,10 +1156,11 @@ class VMobject(Mobject):
self.data["joint_product"][:, 3] = (vect_to_vert * vect_from_vert).sum(1)
return self.data["joint_product"]
def lock_matching_data(self, vmobject1: VMobject, vmobject2: VMobject):
def lock_matching_data(self, vmobject1: VMobject, vmobject2: VMobject) -> Self:
for mob in [self, vmobject1, vmobject2]:
mob.get_joint_products()
super().lock_matching_data(vmobject1, vmobject2)
return self
def triggers_refreshed_triangulation(func: Callable):
@wraps(func)
@ -1189,7 +1172,7 @@ class VMobject(Mobject):
return self
return wrapper
def set_points(self, points: Vect3Array, refresh_joints: bool = True):
def set_points(self, points: Vect3Array, refresh_joints: bool = True) -> Self:
assert(len(points) == 0 or len(points) % 2 == 1)
super().set_points(points)
self.refresh_triangulation()
@ -1199,13 +1182,13 @@ class VMobject(Mobject):
return self
@triggers_refreshed_triangulation
def append_points(self, points: Vect3Array):
def append_points(self, points: Vect3Array) -> Self:
assert(len(points) % 2 == 0)
super().append_points(points)
return self
@triggers_refreshed_triangulation
def reverse_points(self, recurse: bool = True):
def reverse_points(self, recurse: bool = True) -> Self:
# This will reset which anchors are
# considered path ends
for mob in self.get_family(recurse):
@ -1218,7 +1201,7 @@ class VMobject(Mobject):
return self
@triggers_refreshed_triangulation
def set_data(self, data: np.ndarray):
def set_data(self, data: np.ndarray) -> Self:
super().set_data(data)
return self
@ -1229,15 +1212,16 @@ class VMobject(Mobject):
function: Callable[[Vect3], Vect3],
make_smooth: bool = False,
**kwargs
):
) -> Self:
super().apply_function(function, **kwargs)
if self.make_smooth_after_applying_functions or make_smooth:
self.make_smooth(approx=True)
return self
def apply_points_function(self, *args, **kwargs):
def apply_points_function(self, *args, **kwargs) -> Self:
super().apply_points_function(*args, **kwargs)
self.refresh_joint_products()
return self
# For shaders
def init_shader_data(self, ctx: Context):
@ -1272,7 +1256,7 @@ class VMobject(Mobject):
self.stroke_shader_wrapper,
]
def refresh_shader_wrapper_id(self):
def refresh_shader_wrapper_id(self) -> Self:
if not self._shaders_initialized:
return self
for wrapper in self.shader_wrappers:
@ -1339,7 +1323,7 @@ class VGroup(VMobject):
super().__init__(**kwargs)
self.add(*vmobjects)
def __add__(self, other: VMobject | VGroup):
def __add__(self, other: VMobject) -> Self:
assert(isinstance(other, VMobject))
return self.add(other)

View file

@ -1,7 +1,7 @@
from __future__ import annotations
import numpy as np
from typing import Self
from manimlib.mobject.mobject import Mobject
from manimlib.utils.iterables import listify
@ -36,7 +36,7 @@ class ValueTracker(Mobject):
return result[0]
return result
def set_value(self, value: float | complex | np.ndarray):
def set_value(self, value: float | complex | np.ndarray) -> Self:
self.uniforms["value"][:] = value
return self

View file

@ -22,7 +22,6 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Callable, Iterable, Sequence, TypeVar, Tuple
import numpy.typing as npt
from manimlib.typing import ManimColor, Vect3, VectN, Vect3Array
from manimlib.mobject.coordinate_systems import CoordinateSystem