Rework DotCloud and PMobject for new data structure

This commit is contained in:
Grant Sanderson 2021-01-11 12:40:21 -10:00
parent a5dd08cca7
commit c408adeefa
2 changed files with 80 additions and 122 deletions

View file

@ -3,10 +3,8 @@ import moderngl
import numbers
from manimlib.constants import GREY_C
from manimlib.constants import ORIGIN
from manimlib.mobject.types.point_cloud_mobject import PMobject
from manimlib.mobject.geometry import DEFAULT_DOT_RADIUS
from manimlib.utils.bezier import interpolate
from manimlib.utils.iterables import resize_preserving_order
@ -14,6 +12,7 @@ class DotCloud(PMobject):
CONFIG = {
"color": GREY_C,
"opacity": 1,
"radii": DEFAULT_DOT_RADIUS,
"shader_folder": "true_dot",
"render_primitive": moderngl.POINTS,
"shader_dtype": [
@ -23,26 +22,28 @@ class DotCloud(PMobject):
],
}
def __init__(self, points=[[ORIGIN]], radii=DEFAULT_DOT_RADIUS, **kwargs):
def __init__(self, points=None, **kwargs):
super().__init__(**kwargs)
self.rgbas = np.zeros((len(points), 4))
self.radii = np.full((len(points), 1), radii)
self.points = np.array(points)
self.set_color(self.color)
if points:
self.set_points(points)
def set_points(self, points):
super().set_points(points)
self.radii = resize_preserving_order(self.radii, len(points))
return self
def init_data(self):
self.data = {
"points": np.zeros((1, 3)),
"rgbas": np.zeros((1, 4)),
"radii": np.zeros((1, 1))
}
self.set_radii(self.radii)
def set_points_by_grid(self, n_rows, n_cols, height=None, width=None):
# TODO, add buff/hbuff/vbuff args...
new_points = np.zeros((n_rows * n_cols, 3))
new_points[:, 0] = np.tile(range(n_cols), n_rows)
new_points[:, 1] = np.repeat(range(n_rows), n_cols)
new_points[:, 2] = 0
self.set_points(new_points)
radius = self.radii[0]
radius = self.data["radii"].max()
if height is None:
height = n_rows * 3 * radius
if width is None:
@ -55,26 +56,19 @@ class DotCloud(PMobject):
return self
def set_radii(self, radii):
if isinstance(radii, numbers.Number):
self.radii[:] = radii
else:
self.radii = resize_preserving_order(radii, len(self.points))
if not isinstance(radii, numbers.Number):
radii = resize_preserving_order(radii, self.get_num_points())
self.data["radii"][:, 0] = radii
return self
def scale(self, scale_factor, scale_radii=True, **kwargs):
super().scale(scale_factor, **kwargs)
if scale_radii:
self.radii *= scale_factor
return self
def interpolate(self, mobject1, mobject2, alpha, *args, **kwargs):
super().interpolate(mobject1, mobject2, alpha, *args, **kwargs)
self.radii = interpolate(mobject1.radii, mobject2.radii, alpha)
self.data["radii"] *= scale_factor
return self
def get_shader_data(self):
data = self.get_blank_shader_data_array(len(self.points))
data["point"] = self.points
data["radius"] = self.radii.reshape((len(self.radii), 1))
data["color"] = self.rgbas
data = super().get_shader_data()
data["radius"] = self.data["radii"]
data["color"] = self.data["rgbas"]
return data

View file

@ -1,117 +1,87 @@
from manimlib.constants import *
from manimlib.mobject.mobject import Mobject
from manimlib.utils.bezier import interpolate
from manimlib.utils.color import color_gradient
from manimlib.utils.color import color_to_rgba
from manimlib.utils.color import color_to_rgb
from manimlib.utils.color import rgba_to_color
from manimlib.utils.iterables import resize_preserving_order
from manimlib.utils.iterables import resize_with_interpolation
from manimlib.utils.iterables import resize_array
class PMobject(Mobject):
def reset_points(self):
self.rgbas = np.zeros((0, 4))
self.points = np.zeros((0, 3))
CONFIG = {
"opacity": 1.0,
}
def init_data(self):
self.data = {
"points": np.zeros((0, 3)),
"rgbas": np.zeros((0, 4)),
}
def init_colors(self):
self.set_color(self.color, self.opacity)
def resize_points(self, size, resize_func=resize_array):
for key in self.data:
if len(self.data[key]) != size:
self.data[key] = resize_array(self.data[key], size)
return self
def get_array_attrs(self):
return Mobject.get_array_attrs(self) + ["rgbas"]
def set_points(self, points):
self.points = points
self.rgbas = resize_preserving_order(self.rgbas, len(points))
return self
def add_points(self, points, rgbas=None, color=None, alpha=1):
def add_points(self, points, rgbas=None, color=None, opacity=None):
"""
points must be a Nx3 numpy array, as must rgbas if it is not None
"""
if not isinstance(points, np.ndarray):
points = np.array(points)
num_new_points = len(points)
self.points = np.vstack([self.points, points])
if rgbas is None:
color = Color(color) if color else self.color
self.append_points(points)
if color is not None:
if opacity is None:
opacity = self.data["rgbas"][-1, 3]
rgbas = np.repeat(
[color_to_rgba(color, alpha)],
num_new_points,
[color_to_rgba(color, opacity)],
len(points),
axis=0
)
elif len(rgbas) != len(points):
raise Exception("points and rgbas must have same shape")
self.rgbas = np.vstack([self.rgbas, rgbas])
elif rgbas is not None:
self.data["rgbas"][-len(rgbas):] = rgbas
return self
def set_color(self, color, family=True):
rgba = color_to_rgba(color)
mobs = self.family_members_with_points() if family else [self]
def set_color(self, color, opacity=None, family=True):
rgb = color_to_rgb(color)
mobs = self.get_family() if family else [self]
for mob in mobs:
mob.rgbas[:, :] = rgba
mob.data["rgbas"][:, :3] = rgb
if opacity is not None:
self.set_opacity(opacity)
return self
def set_opacity(self, opacity, family=True):
mobs = self.family_members_with_points() if family else [self]
mobs = self.get_family() if family else [self]
for mob in mobs:
mob.rgbas[:, 3] = opacity
mob.data["rgbas"][:, 3] = opacity
return self
def get_color(self):
return rgba_to_color(self.rgbas[0, :])
return rgba_to_color(self.data["rgbas"][0])
def get_all_rgbas(self):
return self.get_merged_array("rgbas")
# def set_color_by_gradient(self, start_color, end_color):
def set_color_by_gradient(self, *colors):
self.rgbas = np.array(list(map(
self.data["rgbas"] = np.array(list(map(
color_to_rgba,
color_gradient(colors, len(self.points))
)))
return self
start_rgba, end_rgba = list(map(color_to_rgba, [start_color, end_color]))
for mob in self.family_members_with_points():
num_points = mob.get_num_points()
mob.rgbas = np.array([
interpolate(start_rgba, end_rgba, alpha)
for alpha in np.arange(num_points) / float(num_points)
])
return self
def set_colors_by_radial_gradient(self, center=None, radius=1, inner_color=WHITE, outer_color=BLACK):
start_rgba, end_rgba = list(map(color_to_rgba, [start_color, end_color]))
if center is None:
center = self.get_center()
for mob in self.family_members_with_points():
num_points = mob.get_num_points()
t = min(1, np.abs(mob.get_center() - center) / radius)
mob.rgbas = np.array(
[interpolate(start_rgba, end_rgba, t)] * num_points
)
return self
def match_colors(self, pmobject):
self.rgbas[:] = resize_preserving_order(pmobject.rgbas, len(self.points))
self.data["rgbas"][:] = resize_with_interpolation(
pmobject.data["rgbas"], self.get_num_points()
)
return self
def filter_out(self, condition):
for mob in self.family_members_with_points():
to_eliminate = ~np.apply_along_axis(condition, 1, mob.points)
mob.points = mob.points[to_eliminate]
mob.rgbas = mob.rgbas[to_eliminate]
return self
def thin_out(self, factor=5):
"""
Removes all but every nth point for n = factor
"""
for mob in self.family_members_with_points():
num_points = self.get_num_points()
mob.apply_over_attr_arrays(
lambda arr: arr[
np.arange(0, num_points, factor)
]
)
to_keep = ~np.apply_along_axis(condition, 1, mob.get_points())
for key in mob.data:
mob.data[key] = mob.data[key][to_keep]
return self
def sort_points(self, function=lambda p: p[0]):
@ -120,43 +90,37 @@ class PMobject(Mobject):
"""
for mob in self.family_members_with_points():
indices = np.argsort(
np.apply_along_axis(function, 1, mob.points)
np.apply_along_axis(function, 1, mob.get_points())
)
mob.apply_over_attr_arrays(lambda arr: arr[indices])
for key in mob.data:
mob.data[key] = mob.data[key][indices]
return self
def ingest_submobjects(self):
attrs = self.get_array_attrs()
arrays = list(map(self.get_merged_array, attrs))
for attr, array in zip(attrs, arrays):
setattr(self, attr, array)
self.set_submobjects([])
for key in self.data:
self.data[key] = np.vstack([
sm.data[key]
for sm in self.get_family()
])
return self
def point_from_proportion(self, alpha):
index = alpha * (self.get_num_points() - 1)
return self.points[index]
# Alignment
def interpolate_color(self, mobject1, mobject2, alpha):
self.rgbas = interpolate(mobject1.rgbas, mobject2.rgbas, alpha)
return self
return self.get_points()[int(index)]
def pointwise_become_partial(self, pmobject, a, b):
lower_index = int(a * pmobject.get_num_points())
upper_index = int(b * pmobject.get_num_points())
for attr in self.get_array_attrs():
full_array = getattr(pmobject, attr)
partial_array = full_array[lower_index:upper_index]
setattr(self, attr, partial_array)
for key in self.data:
self.data[key] = pmobject.data[key][lower_index:upper_index]
return self
class PGroup(PMobject):
def __init__(self, *pmobs, **kwargs):
if not all([isinstance(m, PMobject) for m in pmobs]):
raise Exception("All submobjects must be of type PMobject")
super().__init__(**kwargs)
self.add(*pmobs)
super().__init__(*pmobs, **kwargs)
class Point(PMobject):
@ -165,5 +129,5 @@ class Point(PMobject):
}
def __init__(self, location=ORIGIN, **kwargs):
PMobject.__init__(self, **kwargs)
super().__init__(**kwargs)
self.add_points([location])