Set the stage for data to be treated as a structure numpy array

This commit is contained in:
Grant Sanderson 2023-01-15 12:34:59 -08:00
parent f2e91ef66f
commit 286b8fb6c3
6 changed files with 43 additions and 50 deletions

View file

@ -106,6 +106,10 @@ class Mobject(object):
self.bounding_box: Vect3Array = np.zeros((3, 3))
self.init_data()
self._data_defaults = {
key: np.zeros((1, self.data[key].shape[1]))
for key in self.data
}
self.init_uniforms()
self.init_updaters()
self.init_event_listners()
@ -130,7 +134,7 @@ class Mobject(object):
def init_data(self):
self.data: dict[str, np.ndarray] = {
"points": np.zeros((0, 3)),
"rgbas": np.zeros((1, 4)),
"rgbas": np.zeros((0, 4)),
}
def init_uniforms(self):
@ -172,7 +176,15 @@ class Mobject(object):
new_length: int,
resize_func: Callable[[np.ndarray, int], np.ndarray] = resize_array
):
for key in self.aligned_data_keys:
if new_length == 0:
for key in self.data:
if len(self.data[key]) > 0:
self._data_defaults[key][:1] = self.data[key][:1]
elif self.get_num_points() == 0:
for key in self.data:
self.data[key] = self._data_defaults[key].copy()
for key in self.data:
self.data[key] = resize_func(self.data[key], new_length)
self.refresh_bounding_box()
return self
@ -1215,7 +1227,8 @@ class Mobject(object):
recurse: bool = False
):
for mob in self.get_family(recurse):
mob.data[name] = np.array(rgba_array)
data = mob.data if mob.get_num_points() > 0 else mob._data_defaults
data[name][:] = rgba_array
return self
def set_color_by_rgba_func(
@ -1252,22 +1265,14 @@ class Mobject(object):
name: str = "rgbas",
recurse: bool = True
):
max_len = 0
if color is not None:
rgbs = np.array([color_to_rgb(c) for c in listify(color)])
max_len = len(rgbs)
if opacity is not None:
opacities = np.array(listify(opacity))
max_len = max(max_len, len(opacities))
for mob in self.get_family(recurse):
if max_len > len(mob.data[name]):
mob.data[name] = resize_array(mob.data[name], max_len)
size = len(mob.data[name])
data = mob.data if mob.has_points() > 0 else mob._data_defaults
if color is not None:
mob.data[name][:, :3] = resize_array(rgbs, size)
rgbs = np.array([color_to_rgb(c) for c in listify(color)])
data[name][:, :3] = resize_with_interpolation(rgbs, len(data[name]))
if opacity is not None:
mob.data[name][:, 3] = resize_array(opacities, size)
opacities = np.array(listify(opacity))
data[name][:, 3] = resize_with_interpolation(opacities, len(data[name]))
return self
def set_color(
@ -1869,7 +1874,7 @@ class Mobject(object):
result.append(shader_wrapper)
return result
def check_data_alignment(self, array: Iterable, data_key: str):
def check_data_alignment(self, array: np.ndarray, data_key: str):
# Makes sure that self.data[key] can be broadcast into
# the given array, meaning its length has to be either 1
# or the length of the array
@ -1895,7 +1900,7 @@ class Mobject(object):
):
if data_key in self.locked_data_keys:
return
self.check_data_alignment(shader_data, data_key)
self.check_data_alignment(shader_data, data_key) # TODO, make sure this can be removed
shader_data[shader_data_key] = self.data[data_key]
def get_shader_data(self):

View file

@ -6,7 +6,7 @@ import numpy as np
from manimlib.constants import GREY_C, YELLOW
from manimlib.constants import ORIGIN, NULL_POINTS
from manimlib.mobject.types.point_cloud_mobject import PMobject
from manimlib.utils.iterables import resize_preserving_order
from manimlib.utils.iterables import resize_with_interpolation
from typing import TYPE_CHECKING
@ -101,7 +101,7 @@ class DotCloud(PMobject):
def set_radii(self, radii: npt.ArrayLike):
n_points = len(self.get_points())
radii = np.array(radii).reshape((len(radii), 1))
self.data["radii"] = resize_preserving_order(radii, n_points)
self.data["radii"][:] = resize_with_interpolation(radii, n_points)
self.refresh_bounding_box()
return self

View file

@ -8,6 +8,7 @@ from manimlib.mobject.mobject import Mobject
from manimlib.utils.bezier import inverse_interpolate
from manimlib.utils.images import get_full_raster_image_path
from manimlib.utils.iterables import listify
from manimlib.utils.iterables import resize_with_interpolation
from typing import TYPE_CHECKING
@ -39,7 +40,7 @@ class ImageMobject(Mobject):
self.data = {
"points": np.array([UL, DL, UR, DR]),
"im_coords": np.array([(0, 0), (0, 1), (1, 0), (1, 1)]),
"opacity": np.array([[self.opacity]], dtype=np.float32),
"opacity": self.opacity * np.ones((4, 1)),
}
def init_points(self) -> None:
@ -48,8 +49,9 @@ class ImageMobject(Mobject):
self.set_height(self.height)
def set_opacity(self, opacity: float, recurse: bool = True):
op_arr = np.array([[o] for o in listify(opacity)])
for mob in self.get_family(recurse):
mob.data["opacity"] = np.array([[o] for o in listify(opacity)])
mob.data["opacity"][:] = resize_with_interpolation(op_arr, mob.get_num_points())
return self
def set_color(self, color, opacity=None, recurse=None):

View file

@ -16,17 +16,6 @@ if TYPE_CHECKING:
class PMobject(Mobject):
def resize_points(
self,
size: int,
resize_func: Callable[[np.ndarray, int], np.ndarray] = resize_array
):
# TODO
for key in self.data:
if len(self.data[key]) != size:
self.data[key] = resize_func(self.data[key], size)
return self
def set_points(self, points: Vect3Array):
if len(points) == 0:
points = np.zeros((0, 3))
@ -64,7 +53,7 @@ class PMobject(Mobject):
return self
def set_color_by_gradient(self, *colors: ManimColor):
self.data["rgbas"] = np.array(list(map(
self.data["rgbas"][:] = np.array(list(map(
color_to_rgba,
color_gradient(colors, self.get_num_points())
)))
@ -92,7 +81,7 @@ class PMobject(Mobject):
np.apply_along_axis(function, 1, mob.get_points())
)
for key in mob.data:
mob.data[key] = mob.data[key][indices]
mob.data[key][:] = mob.data[key][indices]
return self
def ingest_submobjects(self):

View file

@ -10,6 +10,7 @@ from manimlib.utils.bezier import integer_interpolate
from manimlib.utils.bezier import interpolate
from manimlib.utils.images import get_full_raster_image_path
from manimlib.utils.iterables import listify
from manimlib.utils.iterables import resize_with_interpolation
from manimlib.utils.space_ops import normalize_along_axis
from typing import TYPE_CHECKING
@ -336,8 +337,9 @@ class TexturedSurface(Surface):
self.data["opacity"] = np.array([self.uv_surface.data["rgbas"][:, 3]])
def set_opacity(self, opacity: float, recurse: bool = True):
op_arr = np.array([[o] for o in listify(opacity)])
for mob in self.get_family(recurse):
mob.data["opacity"] = np.array([[o] for o in listify(opacity)])
mob.data["opacity"][:] = resize_with_interpolation(op_arr, len(mob.data["opacity"]))
return self
def pointwise_become_partial(

View file

@ -29,6 +29,7 @@ from manimlib.utils.iterables import listify
from manimlib.utils.iterables import make_even
from manimlib.utils.iterables import resize_array
from manimlib.utils.iterables import resize_with_interpolation
from manimlib.utils.iterables import arrays_match
from manimlib.utils.space_ops import angle_between_vectors
from manimlib.utils.space_ops import cross2d
from manimlib.utils.space_ops import earclip_triangulation
@ -209,11 +210,11 @@ class VMobject(Mobject):
if width is not None:
for mob in self.get_family(recurse):
if isinstance(width, np.ndarray):
arr = width.reshape((len(width), 1))
else:
arr = np.array([[w] for w in listify(width)], dtype=float)
mob.data['stroke_width'] = arr
data = mob.data if mob.get_num_points() > 0 else mob._data_defaults
data['stroke_width'][:, 0] = resize_with_interpolation(
np.array(listify(width)).flatten(),
len(data['stroke_width'])
)
if background is not None:
for mob in self.get_family(recurse):
@ -252,7 +253,7 @@ class VMobject(Mobject):
):
for mob in self.get_family(recurse):
if fill_rgba is not None:
mob.data['fill_rgba'] = resize_with_interpolation(fill_rgba, len(fill_rgba))
mob.data['fill_rgba'][:] = resize_with_interpolation(fill_rgba, len(mob.data['fill_rgba']))
else:
mob.set_fill(
color=fill_color,
@ -261,7 +262,7 @@ class VMobject(Mobject):
)
if stroke_rgba is not None:
mob.data['stroke_rgba'] = resize_with_interpolation(stroke_rgba, len(stroke_rgba))
mob.data['stroke_rgba'][:] = resize_with_interpolation(stroke_rgba, len(mob.data['stroke_rgba']))
mob.set_stroke(
width=stroke_width,
background=stroke_background,
@ -926,7 +927,7 @@ class VMobject(Mobject):
if self.has_fill():
tri1 = mobject1.get_triangulation()
tri2 = mobject2.get_triangulation()
if len(tri1) != len(tri2) or not (tri1 == tri2).all():
if not arrays_match(tri1, tri2):
self.refresh_triangulation()
return self
@ -991,10 +992,6 @@ class VMobject(Mobject):
def refresh_triangulation(self):
for mob in self.get_family():
mob.needs_new_triangulation = True
mob.data["orientation"] = resize_array(
mob.data["orientation"],
mob.get_num_points()
)
return self
def get_triangulation(self):
@ -1023,7 +1020,6 @@ class VMobject(Mobject):
curve_orientations = np.sign(cross2d(v01s, v12s))
# Reset orientation data
self.data["orientation"] = resize_array(self.data["orientation"], len(points))
self.data["orientation"][1::2, 0] = curve_orientations
if "orientation" in self.locked_data_keys:
self.locked_data_keys.remove("orientation")
@ -1072,7 +1068,6 @@ class VMobject(Mobject):
self.needs_new_joint_angles = False
points = self.get_points()
self.data["joint_angle"] = resize_array(self.data["joint_angle"], len(points))
if(len(points) < 3):
return self.data["joint_angle"]