diff --git a/manimlib/mobject/mobject.py b/manimlib/mobject/mobject.py index 300a279e..14ca6db2 100644 --- a/manimlib/mobject/mobject.py +++ b/manimlib/mobject/mobject.py @@ -106,6 +106,10 @@ class Mobject(object): self.bounding_box: Vect3Array = np.zeros((3, 3)) self.init_data() + self._data_defaults = { + key: np.zeros((1, self.data[key].shape[1])) + for key in self.data + } self.init_uniforms() self.init_updaters() self.init_event_listners() @@ -130,7 +134,7 @@ class Mobject(object): def init_data(self): self.data: dict[str, np.ndarray] = { "points": np.zeros((0, 3)), - "rgbas": np.zeros((1, 4)), + "rgbas": np.zeros((0, 4)), } def init_uniforms(self): @@ -172,7 +176,15 @@ class Mobject(object): new_length: int, resize_func: Callable[[np.ndarray, int], np.ndarray] = resize_array ): - for key in self.aligned_data_keys: + if new_length == 0: + for key in self.data: + if len(self.data[key]) > 0: + self._data_defaults[key][:1] = self.data[key][:1] + elif self.get_num_points() == 0: + for key in self.data: + self.data[key] = self._data_defaults[key].copy() + + for key in self.data: self.data[key] = resize_func(self.data[key], new_length) self.refresh_bounding_box() return self @@ -1215,7 +1227,8 @@ class Mobject(object): recurse: bool = False ): for mob in self.get_family(recurse): - mob.data[name] = np.array(rgba_array) + data = mob.data if mob.get_num_points() > 0 else mob._data_defaults + data[name][:] = rgba_array return self def set_color_by_rgba_func( @@ -1252,22 +1265,14 @@ class Mobject(object): name: str = "rgbas", recurse: bool = True ): - max_len = 0 - if color is not None: - rgbs = np.array([color_to_rgb(c) for c in listify(color)]) - max_len = len(rgbs) - if opacity is not None: - opacities = np.array(listify(opacity)) - max_len = max(max_len, len(opacities)) - for mob in self.get_family(recurse): - if max_len > len(mob.data[name]): - mob.data[name] = resize_array(mob.data[name], max_len) - size = len(mob.data[name]) + data = mob.data if mob.has_points() > 0 else mob._data_defaults if color is not None: - mob.data[name][:, :3] = resize_array(rgbs, size) + rgbs = np.array([color_to_rgb(c) for c in listify(color)]) + data[name][:, :3] = resize_with_interpolation(rgbs, len(data[name])) if opacity is not None: - mob.data[name][:, 3] = resize_array(opacities, size) + opacities = np.array(listify(opacity)) + data[name][:, 3] = resize_with_interpolation(opacities, len(data[name])) return self def set_color( @@ -1869,7 +1874,7 @@ class Mobject(object): result.append(shader_wrapper) return result - def check_data_alignment(self, array: Iterable, data_key: str): + def check_data_alignment(self, array: np.ndarray, data_key: str): # Makes sure that self.data[key] can be broadcast into # the given array, meaning its length has to be either 1 # or the length of the array @@ -1895,7 +1900,7 @@ class Mobject(object): ): if data_key in self.locked_data_keys: return - self.check_data_alignment(shader_data, data_key) + self.check_data_alignment(shader_data, data_key) # TODO, make sure this can be removed shader_data[shader_data_key] = self.data[data_key] def get_shader_data(self): diff --git a/manimlib/mobject/types/dot_cloud.py b/manimlib/mobject/types/dot_cloud.py index f921677e..025ece0e 100644 --- a/manimlib/mobject/types/dot_cloud.py +++ b/manimlib/mobject/types/dot_cloud.py @@ -6,7 +6,7 @@ import numpy as np from manimlib.constants import GREY_C, YELLOW from manimlib.constants import ORIGIN, NULL_POINTS from manimlib.mobject.types.point_cloud_mobject import PMobject -from manimlib.utils.iterables import resize_preserving_order +from manimlib.utils.iterables import resize_with_interpolation from typing import TYPE_CHECKING @@ -101,7 +101,7 @@ class DotCloud(PMobject): def set_radii(self, radii: npt.ArrayLike): n_points = len(self.get_points()) radii = np.array(radii).reshape((len(radii), 1)) - self.data["radii"] = resize_preserving_order(radii, n_points) + self.data["radii"][:] = resize_with_interpolation(radii, n_points) self.refresh_bounding_box() return self diff --git a/manimlib/mobject/types/image_mobject.py b/manimlib/mobject/types/image_mobject.py index da74dc5a..7b9489c5 100644 --- a/manimlib/mobject/types/image_mobject.py +++ b/manimlib/mobject/types/image_mobject.py @@ -8,6 +8,7 @@ from manimlib.mobject.mobject import Mobject from manimlib.utils.bezier import inverse_interpolate from manimlib.utils.images import get_full_raster_image_path from manimlib.utils.iterables import listify +from manimlib.utils.iterables import resize_with_interpolation from typing import TYPE_CHECKING @@ -39,7 +40,7 @@ class ImageMobject(Mobject): self.data = { "points": np.array([UL, DL, UR, DR]), "im_coords": np.array([(0, 0), (0, 1), (1, 0), (1, 1)]), - "opacity": np.array([[self.opacity]], dtype=np.float32), + "opacity": self.opacity * np.ones((4, 1)), } def init_points(self) -> None: @@ -48,8 +49,9 @@ class ImageMobject(Mobject): self.set_height(self.height) def set_opacity(self, opacity: float, recurse: bool = True): + op_arr = np.array([[o] for o in listify(opacity)]) for mob in self.get_family(recurse): - mob.data["opacity"] = np.array([[o] for o in listify(opacity)]) + mob.data["opacity"][:] = resize_with_interpolation(op_arr, mob.get_num_points()) return self def set_color(self, color, opacity=None, recurse=None): diff --git a/manimlib/mobject/types/point_cloud_mobject.py b/manimlib/mobject/types/point_cloud_mobject.py index 412bc7b8..40b3896d 100644 --- a/manimlib/mobject/types/point_cloud_mobject.py +++ b/manimlib/mobject/types/point_cloud_mobject.py @@ -16,17 +16,6 @@ if TYPE_CHECKING: class PMobject(Mobject): - def resize_points( - self, - size: int, - resize_func: Callable[[np.ndarray, int], np.ndarray] = resize_array - ): - # TODO - for key in self.data: - if len(self.data[key]) != size: - self.data[key] = resize_func(self.data[key], size) - return self - def set_points(self, points: Vect3Array): if len(points) == 0: points = np.zeros((0, 3)) @@ -64,7 +53,7 @@ class PMobject(Mobject): return self def set_color_by_gradient(self, *colors: ManimColor): - self.data["rgbas"] = np.array(list(map( + self.data["rgbas"][:] = np.array(list(map( color_to_rgba, color_gradient(colors, self.get_num_points()) ))) @@ -92,7 +81,7 @@ class PMobject(Mobject): np.apply_along_axis(function, 1, mob.get_points()) ) for key in mob.data: - mob.data[key] = mob.data[key][indices] + mob.data[key][:] = mob.data[key][indices] return self def ingest_submobjects(self): diff --git a/manimlib/mobject/types/surface.py b/manimlib/mobject/types/surface.py index 8ecb09c4..3e726f1a 100644 --- a/manimlib/mobject/types/surface.py +++ b/manimlib/mobject/types/surface.py @@ -10,6 +10,7 @@ from manimlib.utils.bezier import integer_interpolate from manimlib.utils.bezier import interpolate from manimlib.utils.images import get_full_raster_image_path from manimlib.utils.iterables import listify +from manimlib.utils.iterables import resize_with_interpolation from manimlib.utils.space_ops import normalize_along_axis from typing import TYPE_CHECKING @@ -336,8 +337,9 @@ class TexturedSurface(Surface): self.data["opacity"] = np.array([self.uv_surface.data["rgbas"][:, 3]]) def set_opacity(self, opacity: float, recurse: bool = True): + op_arr = np.array([[o] for o in listify(opacity)]) for mob in self.get_family(recurse): - mob.data["opacity"] = np.array([[o] for o in listify(opacity)]) + mob.data["opacity"][:] = resize_with_interpolation(op_arr, len(mob.data["opacity"])) return self def pointwise_become_partial( diff --git a/manimlib/mobject/types/vectorized_mobject.py b/manimlib/mobject/types/vectorized_mobject.py index 4ad9cdec..33ddc356 100644 --- a/manimlib/mobject/types/vectorized_mobject.py +++ b/manimlib/mobject/types/vectorized_mobject.py @@ -29,6 +29,7 @@ from manimlib.utils.iterables import listify from manimlib.utils.iterables import make_even from manimlib.utils.iterables import resize_array from manimlib.utils.iterables import resize_with_interpolation +from manimlib.utils.iterables import arrays_match from manimlib.utils.space_ops import angle_between_vectors from manimlib.utils.space_ops import cross2d from manimlib.utils.space_ops import earclip_triangulation @@ -209,11 +210,11 @@ class VMobject(Mobject): if width is not None: for mob in self.get_family(recurse): - if isinstance(width, np.ndarray): - arr = width.reshape((len(width), 1)) - else: - arr = np.array([[w] for w in listify(width)], dtype=float) - mob.data['stroke_width'] = arr + data = mob.data if mob.get_num_points() > 0 else mob._data_defaults + data['stroke_width'][:, 0] = resize_with_interpolation( + np.array(listify(width)).flatten(), + len(data['stroke_width']) + ) if background is not None: for mob in self.get_family(recurse): @@ -252,7 +253,7 @@ class VMobject(Mobject): ): for mob in self.get_family(recurse): if fill_rgba is not None: - mob.data['fill_rgba'] = resize_with_interpolation(fill_rgba, len(fill_rgba)) + mob.data['fill_rgba'][:] = resize_with_interpolation(fill_rgba, len(mob.data['fill_rgba'])) else: mob.set_fill( color=fill_color, @@ -261,7 +262,7 @@ class VMobject(Mobject): ) if stroke_rgba is not None: - mob.data['stroke_rgba'] = resize_with_interpolation(stroke_rgba, len(stroke_rgba)) + mob.data['stroke_rgba'][:] = resize_with_interpolation(stroke_rgba, len(mob.data['stroke_rgba'])) mob.set_stroke( width=stroke_width, background=stroke_background, @@ -926,7 +927,7 @@ class VMobject(Mobject): if self.has_fill(): tri1 = mobject1.get_triangulation() tri2 = mobject2.get_triangulation() - if len(tri1) != len(tri2) or not (tri1 == tri2).all(): + if not arrays_match(tri1, tri2): self.refresh_triangulation() return self @@ -991,10 +992,6 @@ class VMobject(Mobject): def refresh_triangulation(self): for mob in self.get_family(): mob.needs_new_triangulation = True - mob.data["orientation"] = resize_array( - mob.data["orientation"], - mob.get_num_points() - ) return self def get_triangulation(self): @@ -1023,7 +1020,6 @@ class VMobject(Mobject): curve_orientations = np.sign(cross2d(v01s, v12s)) # Reset orientation data - self.data["orientation"] = resize_array(self.data["orientation"], len(points)) self.data["orientation"][1::2, 0] = curve_orientations if "orientation" in self.locked_data_keys: self.locked_data_keys.remove("orientation") @@ -1072,7 +1068,6 @@ class VMobject(Mobject): self.needs_new_joint_angles = False points = self.get_points() - self.data["joint_angle"] = resize_array(self.data["joint_angle"], len(points)) if(len(points) < 3): return self.data["joint_angle"]