mirror of
https://github.com/3b1b/manim.git
synced 2025-08-31 10:58:36 +00:00
Set the stage for data to be treated as a structure numpy array
This commit is contained in:
parent
f2e91ef66f
commit
286b8fb6c3
6 changed files with 43 additions and 50 deletions
|
@ -106,6 +106,10 @@ class Mobject(object):
|
|||
self.bounding_box: Vect3Array = np.zeros((3, 3))
|
||||
|
||||
self.init_data()
|
||||
self._data_defaults = {
|
||||
key: np.zeros((1, self.data[key].shape[1]))
|
||||
for key in self.data
|
||||
}
|
||||
self.init_uniforms()
|
||||
self.init_updaters()
|
||||
self.init_event_listners()
|
||||
|
@ -130,7 +134,7 @@ class Mobject(object):
|
|||
def init_data(self):
|
||||
self.data: dict[str, np.ndarray] = {
|
||||
"points": np.zeros((0, 3)),
|
||||
"rgbas": np.zeros((1, 4)),
|
||||
"rgbas": np.zeros((0, 4)),
|
||||
}
|
||||
|
||||
def init_uniforms(self):
|
||||
|
@ -172,7 +176,15 @@ class Mobject(object):
|
|||
new_length: int,
|
||||
resize_func: Callable[[np.ndarray, int], np.ndarray] = resize_array
|
||||
):
|
||||
for key in self.aligned_data_keys:
|
||||
if new_length == 0:
|
||||
for key in self.data:
|
||||
if len(self.data[key]) > 0:
|
||||
self._data_defaults[key][:1] = self.data[key][:1]
|
||||
elif self.get_num_points() == 0:
|
||||
for key in self.data:
|
||||
self.data[key] = self._data_defaults[key].copy()
|
||||
|
||||
for key in self.data:
|
||||
self.data[key] = resize_func(self.data[key], new_length)
|
||||
self.refresh_bounding_box()
|
||||
return self
|
||||
|
@ -1215,7 +1227,8 @@ class Mobject(object):
|
|||
recurse: bool = False
|
||||
):
|
||||
for mob in self.get_family(recurse):
|
||||
mob.data[name] = np.array(rgba_array)
|
||||
data = mob.data if mob.get_num_points() > 0 else mob._data_defaults
|
||||
data[name][:] = rgba_array
|
||||
return self
|
||||
|
||||
def set_color_by_rgba_func(
|
||||
|
@ -1252,22 +1265,14 @@ class Mobject(object):
|
|||
name: str = "rgbas",
|
||||
recurse: bool = True
|
||||
):
|
||||
max_len = 0
|
||||
if color is not None:
|
||||
rgbs = np.array([color_to_rgb(c) for c in listify(color)])
|
||||
max_len = len(rgbs)
|
||||
if opacity is not None:
|
||||
opacities = np.array(listify(opacity))
|
||||
max_len = max(max_len, len(opacities))
|
||||
|
||||
for mob in self.get_family(recurse):
|
||||
if max_len > len(mob.data[name]):
|
||||
mob.data[name] = resize_array(mob.data[name], max_len)
|
||||
size = len(mob.data[name])
|
||||
data = mob.data if mob.has_points() > 0 else mob._data_defaults
|
||||
if color is not None:
|
||||
mob.data[name][:, :3] = resize_array(rgbs, size)
|
||||
rgbs = np.array([color_to_rgb(c) for c in listify(color)])
|
||||
data[name][:, :3] = resize_with_interpolation(rgbs, len(data[name]))
|
||||
if opacity is not None:
|
||||
mob.data[name][:, 3] = resize_array(opacities, size)
|
||||
opacities = np.array(listify(opacity))
|
||||
data[name][:, 3] = resize_with_interpolation(opacities, len(data[name]))
|
||||
return self
|
||||
|
||||
def set_color(
|
||||
|
@ -1869,7 +1874,7 @@ class Mobject(object):
|
|||
result.append(shader_wrapper)
|
||||
return result
|
||||
|
||||
def check_data_alignment(self, array: Iterable, data_key: str):
|
||||
def check_data_alignment(self, array: np.ndarray, data_key: str):
|
||||
# Makes sure that self.data[key] can be broadcast into
|
||||
# the given array, meaning its length has to be either 1
|
||||
# or the length of the array
|
||||
|
@ -1895,7 +1900,7 @@ class Mobject(object):
|
|||
):
|
||||
if data_key in self.locked_data_keys:
|
||||
return
|
||||
self.check_data_alignment(shader_data, data_key)
|
||||
self.check_data_alignment(shader_data, data_key) # TODO, make sure this can be removed
|
||||
shader_data[shader_data_key] = self.data[data_key]
|
||||
|
||||
def get_shader_data(self):
|
||||
|
|
|
@ -6,7 +6,7 @@ import numpy as np
|
|||
from manimlib.constants import GREY_C, YELLOW
|
||||
from manimlib.constants import ORIGIN, NULL_POINTS
|
||||
from manimlib.mobject.types.point_cloud_mobject import PMobject
|
||||
from manimlib.utils.iterables import resize_preserving_order
|
||||
from manimlib.utils.iterables import resize_with_interpolation
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
|
@ -101,7 +101,7 @@ class DotCloud(PMobject):
|
|||
def set_radii(self, radii: npt.ArrayLike):
|
||||
n_points = len(self.get_points())
|
||||
radii = np.array(radii).reshape((len(radii), 1))
|
||||
self.data["radii"] = resize_preserving_order(radii, n_points)
|
||||
self.data["radii"][:] = resize_with_interpolation(radii, n_points)
|
||||
self.refresh_bounding_box()
|
||||
return self
|
||||
|
||||
|
|
|
@ -8,6 +8,7 @@ from manimlib.mobject.mobject import Mobject
|
|||
from manimlib.utils.bezier import inverse_interpolate
|
||||
from manimlib.utils.images import get_full_raster_image_path
|
||||
from manimlib.utils.iterables import listify
|
||||
from manimlib.utils.iterables import resize_with_interpolation
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
|
@ -39,7 +40,7 @@ class ImageMobject(Mobject):
|
|||
self.data = {
|
||||
"points": np.array([UL, DL, UR, DR]),
|
||||
"im_coords": np.array([(0, 0), (0, 1), (1, 0), (1, 1)]),
|
||||
"opacity": np.array([[self.opacity]], dtype=np.float32),
|
||||
"opacity": self.opacity * np.ones((4, 1)),
|
||||
}
|
||||
|
||||
def init_points(self) -> None:
|
||||
|
@ -48,8 +49,9 @@ class ImageMobject(Mobject):
|
|||
self.set_height(self.height)
|
||||
|
||||
def set_opacity(self, opacity: float, recurse: bool = True):
|
||||
op_arr = np.array([[o] for o in listify(opacity)])
|
||||
for mob in self.get_family(recurse):
|
||||
mob.data["opacity"] = np.array([[o] for o in listify(opacity)])
|
||||
mob.data["opacity"][:] = resize_with_interpolation(op_arr, mob.get_num_points())
|
||||
return self
|
||||
|
||||
def set_color(self, color, opacity=None, recurse=None):
|
||||
|
|
|
@ -16,17 +16,6 @@ if TYPE_CHECKING:
|
|||
|
||||
|
||||
class PMobject(Mobject):
|
||||
def resize_points(
|
||||
self,
|
||||
size: int,
|
||||
resize_func: Callable[[np.ndarray, int], np.ndarray] = resize_array
|
||||
):
|
||||
# TODO
|
||||
for key in self.data:
|
||||
if len(self.data[key]) != size:
|
||||
self.data[key] = resize_func(self.data[key], size)
|
||||
return self
|
||||
|
||||
def set_points(self, points: Vect3Array):
|
||||
if len(points) == 0:
|
||||
points = np.zeros((0, 3))
|
||||
|
@ -64,7 +53,7 @@ class PMobject(Mobject):
|
|||
return self
|
||||
|
||||
def set_color_by_gradient(self, *colors: ManimColor):
|
||||
self.data["rgbas"] = np.array(list(map(
|
||||
self.data["rgbas"][:] = np.array(list(map(
|
||||
color_to_rgba,
|
||||
color_gradient(colors, self.get_num_points())
|
||||
)))
|
||||
|
@ -92,7 +81,7 @@ class PMobject(Mobject):
|
|||
np.apply_along_axis(function, 1, mob.get_points())
|
||||
)
|
||||
for key in mob.data:
|
||||
mob.data[key] = mob.data[key][indices]
|
||||
mob.data[key][:] = mob.data[key][indices]
|
||||
return self
|
||||
|
||||
def ingest_submobjects(self):
|
||||
|
|
|
@ -10,6 +10,7 @@ from manimlib.utils.bezier import integer_interpolate
|
|||
from manimlib.utils.bezier import interpolate
|
||||
from manimlib.utils.images import get_full_raster_image_path
|
||||
from manimlib.utils.iterables import listify
|
||||
from manimlib.utils.iterables import resize_with_interpolation
|
||||
from manimlib.utils.space_ops import normalize_along_axis
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
@ -336,8 +337,9 @@ class TexturedSurface(Surface):
|
|||
self.data["opacity"] = np.array([self.uv_surface.data["rgbas"][:, 3]])
|
||||
|
||||
def set_opacity(self, opacity: float, recurse: bool = True):
|
||||
op_arr = np.array([[o] for o in listify(opacity)])
|
||||
for mob in self.get_family(recurse):
|
||||
mob.data["opacity"] = np.array([[o] for o in listify(opacity)])
|
||||
mob.data["opacity"][:] = resize_with_interpolation(op_arr, len(mob.data["opacity"]))
|
||||
return self
|
||||
|
||||
def pointwise_become_partial(
|
||||
|
|
|
@ -29,6 +29,7 @@ from manimlib.utils.iterables import listify
|
|||
from manimlib.utils.iterables import make_even
|
||||
from manimlib.utils.iterables import resize_array
|
||||
from manimlib.utils.iterables import resize_with_interpolation
|
||||
from manimlib.utils.iterables import arrays_match
|
||||
from manimlib.utils.space_ops import angle_between_vectors
|
||||
from manimlib.utils.space_ops import cross2d
|
||||
from manimlib.utils.space_ops import earclip_triangulation
|
||||
|
@ -209,11 +210,11 @@ class VMobject(Mobject):
|
|||
|
||||
if width is not None:
|
||||
for mob in self.get_family(recurse):
|
||||
if isinstance(width, np.ndarray):
|
||||
arr = width.reshape((len(width), 1))
|
||||
else:
|
||||
arr = np.array([[w] for w in listify(width)], dtype=float)
|
||||
mob.data['stroke_width'] = arr
|
||||
data = mob.data if mob.get_num_points() > 0 else mob._data_defaults
|
||||
data['stroke_width'][:, 0] = resize_with_interpolation(
|
||||
np.array(listify(width)).flatten(),
|
||||
len(data['stroke_width'])
|
||||
)
|
||||
|
||||
if background is not None:
|
||||
for mob in self.get_family(recurse):
|
||||
|
@ -252,7 +253,7 @@ class VMobject(Mobject):
|
|||
):
|
||||
for mob in self.get_family(recurse):
|
||||
if fill_rgba is not None:
|
||||
mob.data['fill_rgba'] = resize_with_interpolation(fill_rgba, len(fill_rgba))
|
||||
mob.data['fill_rgba'][:] = resize_with_interpolation(fill_rgba, len(mob.data['fill_rgba']))
|
||||
else:
|
||||
mob.set_fill(
|
||||
color=fill_color,
|
||||
|
@ -261,7 +262,7 @@ class VMobject(Mobject):
|
|||
)
|
||||
|
||||
if stroke_rgba is not None:
|
||||
mob.data['stroke_rgba'] = resize_with_interpolation(stroke_rgba, len(stroke_rgba))
|
||||
mob.data['stroke_rgba'][:] = resize_with_interpolation(stroke_rgba, len(mob.data['stroke_rgba']))
|
||||
mob.set_stroke(
|
||||
width=stroke_width,
|
||||
background=stroke_background,
|
||||
|
@ -926,7 +927,7 @@ class VMobject(Mobject):
|
|||
if self.has_fill():
|
||||
tri1 = mobject1.get_triangulation()
|
||||
tri2 = mobject2.get_triangulation()
|
||||
if len(tri1) != len(tri2) or not (tri1 == tri2).all():
|
||||
if not arrays_match(tri1, tri2):
|
||||
self.refresh_triangulation()
|
||||
return self
|
||||
|
||||
|
@ -991,10 +992,6 @@ class VMobject(Mobject):
|
|||
def refresh_triangulation(self):
|
||||
for mob in self.get_family():
|
||||
mob.needs_new_triangulation = True
|
||||
mob.data["orientation"] = resize_array(
|
||||
mob.data["orientation"],
|
||||
mob.get_num_points()
|
||||
)
|
||||
return self
|
||||
|
||||
def get_triangulation(self):
|
||||
|
@ -1023,7 +1020,6 @@ class VMobject(Mobject):
|
|||
curve_orientations = np.sign(cross2d(v01s, v12s))
|
||||
|
||||
# Reset orientation data
|
||||
self.data["orientation"] = resize_array(self.data["orientation"], len(points))
|
||||
self.data["orientation"][1::2, 0] = curve_orientations
|
||||
if "orientation" in self.locked_data_keys:
|
||||
self.locked_data_keys.remove("orientation")
|
||||
|
@ -1072,7 +1068,6 @@ class VMobject(Mobject):
|
|||
self.needs_new_joint_angles = False
|
||||
|
||||
points = self.get_points()
|
||||
self.data["joint_angle"] = resize_array(self.data["joint_angle"], len(points))
|
||||
|
||||
if(len(points) < 3):
|
||||
return self.data["joint_angle"]
|
||||
|
|
Loading…
Add table
Reference in a new issue