Add Self type to dot_cloud.py and point_cloud_mobject.py

This commit is contained in:
Grant Sanderson 2023-01-31 13:49:48 -08:00
parent 3779577d9f
commit af585ca3a1
2 changed files with 18 additions and 17 deletions

View file

@ -13,7 +13,7 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
import numpy.typing as npt import numpy.typing as npt
from typing import Sequence, Tuple from typing import Sequence, Tuple, Self
from manimlib.typing import ManimColor, Vect3, Vect3Array from manimlib.typing import ManimColor, Vect3, Vect3Array
@ -70,7 +70,7 @@ class DotCloud(PMobject):
v_buff_ratio: float = 1.0, v_buff_ratio: float = 1.0,
d_buff_ratio: float = 1.0, d_buff_ratio: float = 1.0,
height: float = DEFAULT_GRID_HEIGHT, height: float = DEFAULT_GRID_HEIGHT,
): ) -> Self:
n_points = n_rows * n_cols * n_layers n_points = n_rows * n_cols * n_layers
points = np.repeat(range(n_points), 3, axis=0).reshape((n_points, 3)) points = np.repeat(range(n_points), 3, axis=0).reshape((n_points, 3))
points[:, 0] = points[:, 0] % n_cols points[:, 0] = points[:, 0] % n_cols
@ -96,7 +96,7 @@ class DotCloud(PMobject):
return self return self
@Mobject.affects_data @Mobject.affects_data
def set_radii(self, radii: npt.ArrayLike): def set_radii(self, radii: npt.ArrayLike) -> Self:
n_points = self.get_num_points() n_points = self.get_num_points()
radii = np.array(radii).reshape((len(radii), 1)) radii = np.array(radii).reshape((len(radii), 1))
self.data["radius"][:] = resize_with_interpolation(radii, n_points) self.data["radius"][:] = resize_with_interpolation(radii, n_points)
@ -107,7 +107,7 @@ class DotCloud(PMobject):
return self.data["radius"] return self.data["radius"]
@Mobject.affects_data @Mobject.affects_data
def set_radius(self, radius: float): def set_radius(self, radius: float) -> Self:
data = self.data if self.get_num_points() > 0 else self._data_defaults data = self.data if self.get_num_points() > 0 else self._data_defaults
data["radius"][:] = radius data["radius"][:] = radius
self.refresh_bounding_box() self.refresh_bounding_box()
@ -116,13 +116,14 @@ class DotCloud(PMobject):
def get_radius(self) -> float: def get_radius(self) -> float:
return self.get_radii().max() return self.get_radii().max()
def set_glow_factor(self, glow_factor: float) -> None: def set_glow_factor(self, glow_factor: float) -> Self:
self.uniforms["glow_factor"] = glow_factor self.uniforms["glow_factor"] = glow_factor
return self
def get_glow_factor(self) -> float: def get_glow_factor(self) -> float:
return self.uniforms["glow_factor"] return self.uniforms["glow_factor"]
def compute_bounding_box(self) -> np.ndarray: def compute_bounding_box(self) -> Vect3Array:
bb = super().compute_bounding_box() bb = super().compute_bounding_box()
radius = self.get_radius() radius = self.get_radius()
bb[0] += np.full((3,), -radius) bb[0] += np.full((3,), -radius)
@ -134,7 +135,7 @@ class DotCloud(PMobject):
scale_factor: float | npt.ArrayLike, scale_factor: float | npt.ArrayLike,
scale_radii: bool = True, scale_radii: bool = True,
**kwargs **kwargs
): ) -> Self:
super().scale(scale_factor, **kwargs) super().scale(scale_factor, **kwargs)
if scale_radii: if scale_radii:
self.set_radii(scale_factor * self.get_radii()) self.set_radii(scale_factor * self.get_radii())
@ -145,7 +146,7 @@ class DotCloud(PMobject):
reflectiveness: float = 0.5, reflectiveness: float = 0.5,
gloss: float = 0.1, gloss: float = 0.1,
shadow: float = 0.2 shadow: float = 0.2
): ) -> Self:
self.set_shading(reflectiveness, gloss, shadow) self.set_shading(reflectiveness, gloss, shadow)
self.apply_depth_test() self.apply_depth_test()
return self return self

View file

@ -10,7 +10,7 @@ from manimlib.utils.iterables import resize_with_interpolation
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Callable from typing import Callable, Self
from manimlib.typing import ManimColor, Vect3, Vect3Array, Vect4Array from manimlib.typing import ManimColor, Vect3, Vect3Array, Vect4Array
@ -28,7 +28,7 @@ class PMobject(Mobject):
rgbas: Vect4Array | None = None, rgbas: Vect4Array | None = None,
color: ManimColor | None = None, color: ManimColor | None = None,
opacity: float | None = None opacity: float | None = None
): ) -> Self:
""" """
points must be a Nx3 numpy array, as must rgbas if it is not None points must be a Nx3 numpy array, as must rgbas if it is not None
""" """
@ -46,13 +46,13 @@ class PMobject(Mobject):
self.data["rgba"][-len(rgbas):] = rgbas self.data["rgba"][-len(rgbas):] = rgbas
return self return self
def add_point(self, point: Vect3, rgba=None, color=None, opacity=None): def add_point(self, point: Vect3, rgba=None, color=None, opacity=None) -> Self:
rgbas = None if rgba is None else [rgba] rgbas = None if rgba is None else [rgba]
self.add_points([point], rgbas, color, opacity) self.add_points([point], rgbas, color, opacity)
return self return self
@Mobject.affects_data @Mobject.affects_data
def set_color_by_gradient(self, *colors: ManimColor): def set_color_by_gradient(self, *colors: ManimColor) -> Self:
self.data["rgba"][:] = np.array(list(map( self.data["rgba"][:] = np.array(list(map(
color_to_rgba, color_to_rgba,
color_gradient(colors, self.get_num_points()) color_gradient(colors, self.get_num_points())
@ -60,20 +60,20 @@ class PMobject(Mobject):
return self return self
@Mobject.affects_data @Mobject.affects_data
def match_colors(self, pmobject: PMobject): def match_colors(self, pmobject: PMobject) -> Self:
self.data["rgba"][:] = resize_with_interpolation( self.data["rgba"][:] = resize_with_interpolation(
pmobject.data["rgba"], self.get_num_points() pmobject.data["rgba"], self.get_num_points()
) )
return self return self
@Mobject.affects_data @Mobject.affects_data
def filter_out(self, condition: Callable[[np.ndarray], bool]): def filter_out(self, condition: Callable[[np.ndarray], bool]) -> Self:
for mob in self.family_members_with_points(): for mob in self.family_members_with_points():
mob.data = mob.data[~np.apply_along_axis(condition, 1, mob.get_points())] mob.data = mob.data[~np.apply_along_axis(condition, 1, mob.get_points())]
return self return self
@Mobject.affects_data @Mobject.affects_data
def sort_points(self, function: Callable[[Vect3], None] = lambda p: p[0]): def sort_points(self, function: Callable[[Vect3], None] = lambda p: p[0]) -> Self:
""" """
function is any map from R^3 to R function is any map from R^3 to R
""" """
@ -85,7 +85,7 @@ class PMobject(Mobject):
return self return self
@Mobject.affects_data @Mobject.affects_data
def ingest_submobjects(self): def ingest_submobjects(self) -> Self:
self.data = np.vstack([ self.data = np.vstack([
sm.data for sm in self.get_family() sm.data for sm in self.get_family()
]) ])
@ -96,7 +96,7 @@ class PMobject(Mobject):
return self.get_points()[int(index)] return self.get_points()[int(index)]
@Mobject.affects_data @Mobject.affects_data
def pointwise_become_partial(self, pmobject: PMobject, a: float, b: float): def pointwise_become_partial(self, pmobject: PMobject, a: float, b: float) -> Self:
lower_index = int(a * pmobject.get_num_points()) lower_index = int(a * pmobject.get_num_points())
upper_index = int(b * pmobject.get_num_points()) upper_index = int(b * pmobject.get_num_points())
self.data = pmobject.data[lower_index:upper_index].copy() self.data = pmobject.data[lower_index:upper_index].copy()