diff --git a/manimlib/mobject/mobject.py b/manimlib/mobject/mobject.py index 14ca6db2..2b72c084 100644 --- a/manimlib/mobject/mobject.py +++ b/manimlib/mobject/mobject.py @@ -66,7 +66,12 @@ class Mobject(object): # Must match in attributes of vert shader shader_dtype: Sequence[Tuple[str, type, Tuple[int]]] = [ ('point', np.float32, (3,)), + ('rgba', np.float32, (4,)), ] + data_dtype: np.dtype = np.dtype([ + ('points', ' 0: - self._data_defaults[key][:1] = self.data[key][:1] + if len(self.data) > 0: + self._data_defaults[:1] = self.data[:1] elif self.get_num_points() == 0: - for key in self.data: - self.data[key] = self._data_defaults[key].copy() + self.data = self._data_defaults.copy() - for key in self.data: - self.data[key] = resize_func(self.data[key], new_length) + self.data = resize_func(self.data, new_length) self.refresh_bounding_box() return self @@ -203,8 +199,7 @@ class Mobject(object): def reverse_points(self): for mob in self.get_family(): - for key in mob.data: - mob.data[key] = mob.data[key][::-1] + mob.data = mob.data[::-1] return self def apply_points_function( @@ -584,10 +579,7 @@ class Mobject(object): # The line above is only a shallow copy, so the internal # data which are numpyu arrays or other mobjects still # need to be further copied. - result.data = { - key: np.array(value) - for key, value in self.data.items() - } + result.data = self.data.copy() result.uniforms = { key: np.array(value) for key, value in self.uniforms.items() @@ -678,15 +670,22 @@ class Mobject(object): if len(fam1) != len(fam2): return False for m1, m2 in zip(fam1, fam2): - for d1, d2 in [(m1.data, m2.data), (m1.uniforms, m2.uniforms)]: - if set(d1).difference(d2): + if m1.get_num_points() != m2.get_num_points(): + return False + if not m1.data.dtype == m2.data.dtype: + return False + for key in m1.data.dtype.names: + if not np.isclose(m1.data[key], m2.data[key]).all(): + return False + if set(m1.uniforms).difference(m2.uniforms): + return False + for key in m1.uniforms: + value1 = m1.uniforms[key] + value2 = m2.uniforms[key] + if isinstance(value1, np.ndarray) and isinstance(value2, np.ndarray) and not value1.size == value2.size: + return False + if not np.isclose(value1, value2).all(): return False - for key in d1: - if isinstance(d1[key], np.ndarray) and isinstance(d2[key], np.ndarray): - if not d1[key].size == d2[key].size: - return False - if not np.isclose(d1[key], d2[key]).all(): - return False return True def has_same_shape_as(self, mobject: Mobject) -> bool: @@ -1604,19 +1603,7 @@ class Mobject(object): # In case any data arrays get resized when aligned to shader data mob1.refresh_shader_data() mob2.refresh_shader_data() - mob1.align_points(mob2) - for key in mob1.data.keys() & mob2.data.keys(): - if key == "points": - # Separate out how points are treated so that subclasses - # can handle that case differently if they choose - continue - arr1 = mob1.data[key] - arr2 = mob2.data[key] - if len(arr2) > len(arr1): - mob1.data[key] = resize_preserving_order(arr1, len(arr2)) - elif len(arr1) > len(arr2): - mob2.data[key] = resize_preserving_order(arr2, len(arr1)) def align_points(self, mobject: Mobject): max_len = max(self.get_num_points(), mobject.get_num_points()) @@ -1686,13 +1673,11 @@ class Mobject(object): alpha: float, path_func: Callable[[np.ndarray, np.ndarray, float], np.ndarray] = straight_path ): - for key in self.data: + for key in self.data.dtype.names: if key in self.locked_data_keys: continue if len(self.data[key]) == 0: continue - if key not in mobject1.data or key not in mobject2.data: - continue func = path_func if key == "points" else interpolate @@ -1739,11 +1724,11 @@ class Mobject(object): def lock_matching_data(self, mobject1: Mobject, mobject2: Mobject): for sm, sm1, sm2 in zip(self.get_family(), mobject1.get_family(), mobject2.get_family()): - keys = sm.data.keys() & sm1.data.keys() & sm2.data.keys() - sm.lock_data(list(filter( - lambda key: arrays_match(sm1.data[key], sm2.data[key]), - keys, - ))) + if not (sm.data.dtype == sm1.data.dtype == sm2.data.dtype): + sm.lock_data([ + key for key in sm.data.dtype.names + if arrays_match(sm1.data[key], sm2.data[key]) + ]) return self def unlock_data(self): diff --git a/manimlib/mobject/types/dot_cloud.py b/manimlib/mobject/types/dot_cloud.py index 025ece0e..11be87a0 100644 --- a/manimlib/mobject/types/dot_cloud.py +++ b/manimlib/mobject/types/dot_cloud.py @@ -29,7 +29,11 @@ class DotCloud(PMobject): ('radius', np.float32, (1,)), ('color', np.float32, (4,)), ] - + data_dtype: np.dtype = np.dtype([ + ('points', np.float32, (3,)), + ('radii', np.float32, (1,)), + ('rgbas', np.float32, (4,)), + ]) def __init__( self, points: Vect3Array = NULL_POINTS, @@ -55,7 +59,6 @@ class DotCloud(PMobject): def init_data(self) -> None: super().init_data() - self.data["radii"] = np.zeros((1, 1)) self.set_radius(self.radius) def init_uniforms(self) -> None: diff --git a/manimlib/mobject/types/image_mobject.py b/manimlib/mobject/types/image_mobject.py index 7b9489c5..12b910f1 100644 --- a/manimlib/mobject/types/image_mobject.py +++ b/manimlib/mobject/types/image_mobject.py @@ -24,6 +24,11 @@ class ImageMobject(Mobject): ('im_coords', np.float32, (2,)), ('opacity', np.float32, (1,)), ] + data_dtype: Sequence[Tuple[str, type, Tuple[int]]] = [ + ('points', np.float32, (3,)), + ('im_coords', np.float32, (2,)), + ('opacity', np.float32, (1,)), + ] def __init__( self, @@ -37,11 +42,10 @@ class ImageMobject(Mobject): super().__init__(texture_paths={"Texture": self.image_path}, **kwargs) def init_data(self) -> None: - self.data = { - "points": np.array([UL, DL, UR, DR]), - "im_coords": np.array([(0, 0), (0, 1), (1, 0), (1, 1)]), - "opacity": self.opacity * np.ones((4, 1)), - } + super().init_data(length=4) + self.data["points"][:] = [UL, DL, UR, DR] + self.data["im_coords"][:] = [(0, 0), (0, 1), (1, 0), (1, 1)] + self.data["opacity"][:] = self.opacity def init_points(self) -> None: size = self.image.size @@ -49,9 +53,10 @@ class ImageMobject(Mobject): self.set_height(self.height) def set_opacity(self, opacity: float, recurse: bool = True): - op_arr = np.array([[o] for o in listify(opacity)]) - for mob in self.get_family(recurse): - mob.data["opacity"][:] = resize_with_interpolation(op_arr, mob.get_num_points()) + self.data["opacity"][:, 0] = resize_with_interpolation( + np.array(listify(opacity)), + self.get_num_points() + ) return self def set_color(self, color, opacity=None, recurse=None): diff --git a/manimlib/mobject/types/point_cloud_mobject.py b/manimlib/mobject/types/point_cloud_mobject.py index 40b3896d..4bb9cb5d 100644 --- a/manimlib/mobject/types/point_cloud_mobject.py +++ b/manimlib/mobject/types/point_cloud_mobject.py @@ -67,9 +67,7 @@ class PMobject(Mobject): def filter_out(self, condition: Callable[[np.ndarray], bool]): for mob in self.family_members_with_points(): - to_keep = ~np.apply_along_axis(condition, 1, mob.get_points()) - for key in mob.data: - mob.data[key] = mob.data[key][to_keep] + mob.data = mob.data[~np.apply_along_axis(condition, 1, mob.get_points())] return self def sort_points(self, function: Callable[[Vect3], None] = lambda p: p[0]): @@ -80,16 +78,13 @@ class PMobject(Mobject): indices = np.argsort( np.apply_along_axis(function, 1, mob.get_points()) ) - for key in mob.data: - mob.data[key][:] = mob.data[key][indices] + mob.data[:] = mob.data[indices] return self def ingest_submobjects(self): - for key in self.data: - self.data[key] = np.vstack([ - sm.data[key] - for sm in self.get_family() - ]) + self.data = np.vstack([ + sm.data for sm in self.get_family() + ]) return self def point_from_proportion(self, alpha: float) -> np.ndarray: @@ -99,8 +94,7 @@ class PMobject(Mobject): def pointwise_become_partial(self, pmobject: PMobject, a: float, b: float): lower_index = int(a * pmobject.get_num_points()) upper_index = int(b * pmobject.get_num_points()) - for key in self.data: - self.data[key] = pmobject.data[key][lower_index:upper_index].copy() + self.data = pmobject.data[lower_index:upper_index].copy() return self diff --git a/manimlib/mobject/types/vectorized_mobject.py b/manimlib/mobject/types/vectorized_mobject.py index 33ddc356..9cdcfdcd 100644 --- a/manimlib/mobject/types/vectorized_mobject.py +++ b/manimlib/mobject/types/vectorized_mobject.py @@ -66,9 +66,16 @@ class VMobject(Mobject): ("stroke_width", np.float32, (1,)), ("color", np.float32, (4,)), ] + data_dtype: Sequence[Tuple[str, type, Tuple[int]]] = [ + ("points", np.float32, (3,)), + ('fill_rgba', np.float32, (4,)), + ("stroke_rgba", np.float32, (4,)), + ("joint_angle", np.float32, (1,)), + ("stroke_width", np.float32, (1,)), + ('orientation', np.float32, (1,)), + ] fill_render_primitive: int = moderngl.TRIANGLES stroke_render_primitive: int = moderngl.TRIANGLE_STRIP - aligned_data_keys = ["points", "orientation", "joint_angle"] pre_function_handle_to_anchor_scale_factor: float = 0.01 make_smooth_after_applying_functions: bool = False @@ -117,17 +124,6 @@ class VMobject(Mobject): def get_group_class(self): return VGroup - def init_data(self): - super().init_data() - self.data.pop("rgbas") - self.data.update({ - "fill_rgba": np.zeros((1, 4)), - "stroke_rgba": np.zeros((1, 4)), - "stroke_width": np.zeros((1, 1)), - "orientation": np.ones((1, 1)), - "joint_angle": np.zeros((0, 1)), - }) - def init_uniforms(self): super().init_uniforms() self.uniforms["anti_alias_width"] = self.anti_alias_width @@ -371,23 +367,28 @@ class VMobject(Mobject): If there are multiple colors (for gradient) this returns the first one """ - return self.get_fill_colors()[0] + data = self.data if self.has_points() else self._data_defaults + return rgb_to_hex(data["fill_rgba"][0, :3]) def get_fill_opacity(self) -> float: """ If there are multiple opacities, this returns the first """ - return self.get_fill_opacities()[0] + data = self.data if self.has_points() else self._data_defaults + return data["fill_rgba"][0, 3] def get_stroke_color(self) -> str: - return self.get_stroke_colors()[0] + data = self.data if self.has_points() else self._data_defaults + return rgb_to_hex(data["stroke_rgba"][0, :3]) def get_stroke_width(self) -> float | np.ndarray: - return self.get_stroke_widths()[0] + data = self.data if self.has_points() else self._data_defaults + return data["stroke_width"][0, 0] def get_stroke_opacity(self) -> float: - return self.get_stroke_opacities()[0] + data = self.data if self.has_points() else self._data_defaults + return data["stroke_rgba"][0, 3] def get_color(self) -> str: if self.has_fill(): @@ -1134,7 +1135,7 @@ class VMobject(Mobject): return self @triggers_refreshed_triangulation - def set_data(self, data: dict): + def set_data(self, data: np.ndarray): super().set_data(data) return self