Keep track of when Mobject data has changed, and used that to determine when ShaderWrapper generates new buffers

This commit is contained in:
Grant Sanderson 2023-01-25 16:43:47 -08:00
parent 4dfabc1c28
commit d2af6a5f4b
9 changed files with 74 additions and 49 deletions

View file

@ -30,15 +30,6 @@ class ShowPartial(Animation, ABC):
self.should_match_start = should_match_start
super().__init__(mobject, **kwargs)
def begin(self) -> None:
super().begin()
if not self.should_match_start:
self.mobject.lock_matching_data(self.mobject, self.starting_mobject)
def finish(self) -> None:
super().finish()
self.mobject.unlock_data()
def interpolate_submobject(
self,
submob: VMobject,
@ -114,11 +105,9 @@ class DrawBorderThenFill(Animation):
self.outline = self.get_outline()
super().begin()
self.mobject.match_style(self.outline)
self.mobject.lock_matching_data(self.mobject, self.outline)
def finish(self) -> None:
super().finish()
self.mobject.unlock_data()
self.mobject.refresh_joint_products()
def get_outline(self) -> VMobject:
@ -146,9 +135,6 @@ class DrawBorderThenFill(Animation):
if index == 1 and self.sm_to_index[hash(submob)] == 0:
# First time crossing over
submob.set_data(outline.data)
submob.unlock_data()
if not self.mobject.has_updaters:
submob.lock_matching_data(submob, start)
submob.needs_new_triangulation = False
self.sm_to_index[hash(submob)] = 1

View file

@ -143,11 +143,6 @@ class Mobject(object):
# Typically implemented in subclass, unlpess purposefully left blank
pass
def set_data(self, data: np.ndarray):
assert(data.dtype == self.data.dtype)
self.data = data
return self
def set_uniforms(self, uniforms: dict):
for key, value in uniforms.items():
if isinstance(value, np.ndarray):
@ -160,8 +155,36 @@ class Mobject(object):
# Borrowed from https://github.com/ManimCommunity/manim/
return _AnimationBuilder(self)
# Only these methods should directly affect points
def note_changed_data(self, recurse_up: bool = True):
self._data_has_changed = True
if recurse_up:
for mob in self.parents:
mob.note_changed_data()
def affects_data(func: Callable):
@wraps(func)
def wrapper(self, *args, **kwargs):
func(self, *args, **kwargs)
self.note_changed_data()
return wrapper
def affects_family_data(func: Callable):
@wraps(func)
def wrapper(self, *args, **kwargs):
func(self, *args, **kwargs)
for mob in self.family_members_with_points():
mob.note_changed_data()
return self
return wrapper
# Only these methods should directly affect points
@affects_data
def set_data(self, data: np.ndarray):
assert(data.dtype == self.data.dtype)
self.data = data
return self
@affects_data
def resize_points(
self,
new_length: int,
@ -177,11 +200,13 @@ class Mobject(object):
self.refresh_bounding_box()
return self
@affects_data
def set_points(self, points: Vect3Array):
self.resize_points(len(points), resize_func=resize_preserving_order)
self.data["point"][:] = points
return self
@affects_data
def append_points(self, new_points: Vect3Array):
n = self.get_num_points()
self.resize_points(n + len(new_points))
@ -192,11 +217,13 @@ class Mobject(object):
self.refresh_bounding_box()
return self
@affects_family_data
def reverse_points(self):
for mob in self.get_family():
mob.data = mob.data[::-1]
return self
@affects_family_data
def apply_points_function(
self,
func: Callable[[np.ndarray], np.ndarray],
@ -330,6 +357,7 @@ class Mobject(object):
def split(self) -> list[Mobject]:
return self.submobjects
@affects_data
def assemble_family(self):
sub_families = (sm.get_family() for sm in self.submobjects)
self.family = [self, *it.chain(*sub_families)]
@ -593,6 +621,7 @@ class Mobject(object):
# won't have changed, just directly match.
result.non_time_updaters = list(self.non_time_updaters)
result.time_based_updaters = list(self.time_based_updaters)
result._data_has_changed = True
family = self.get_family()
for attr, value in list(self.__dict__.items()):
@ -1216,6 +1245,7 @@ class Mobject(object):
# Color functions
@affects_family_data
def set_rgba_array(
self,
rgba_array: npt.ArrayLike,
@ -1254,6 +1284,7 @@ class Mobject(object):
mob.set_rgba_array(rgba_array)
return self
@affects_family_data
def set_rgba_array_by_color(
self,
color: ManimColor | Iterable[ManimColor] | None = None,
@ -1681,6 +1712,7 @@ class Mobject(object):
# Interpolate
@affects_data
def interpolate(
self,
mobject1: Mobject,
@ -1893,11 +1925,10 @@ class Mobject(object):
return self.shader_indices
def render(self, ctx: Context, camera_uniforms: dict):
if self._data_has_changed or self.is_changing():
if self._data_has_changed:
self.shader_wrappers = self.get_shader_wrapper_list(ctx)
for shader_wrapper in self.shader_wrappers:
shader_wrapper.release()
shader_wrapper.get_vao()
shader_wrapper.generate_vao()
self._data_has_changed = False
for shader_wrapper in self.shader_wrappers:
shader_wrapper.uniforms.update(self.get_uniforms())

View file

@ -5,6 +5,7 @@ import numpy as np
from manimlib.constants import GREY_C, YELLOW
from manimlib.constants import ORIGIN, NULL_POINTS
from manimlib.mobject.mobject import Mobject
from manimlib.mobject.types.point_cloud_mobject import PMobject
from manimlib.utils.iterables import resize_with_interpolation
@ -94,6 +95,7 @@ class DotCloud(PMobject):
self.center()
return self
@Mobject.affects_data
def set_radii(self, radii: npt.ArrayLike):
n_points = self.get_num_points()
radii = np.array(radii).reshape((len(radii), 1))
@ -104,6 +106,7 @@ class DotCloud(PMobject):
def get_radii(self) -> np.ndarray:
return self.data["radius"]
@Mobject.affects_data
def set_radius(self, radius: float):
data = self.data if self.get_num_points() > 0 else self._data_defaults
data["radius"][:] = radius

View file

@ -47,6 +47,7 @@ class ImageMobject(Mobject):
self.set_width(2 * size[0] / size[1], stretch=True)
self.set_height(self.height)
@Mobject.affects_data
def set_opacity(self, opacity: float, recurse: bool = True):
self.data["opacity"][:, 0] = resize_with_interpolation(
np.array(listify(opacity)),

View file

@ -5,7 +5,6 @@ import numpy as np
from manimlib.mobject.mobject import Mobject
from manimlib.utils.color import color_gradient
from manimlib.utils.color import color_to_rgba
from manimlib.utils.iterables import resize_array
from manimlib.utils.iterables import resize_with_interpolation
from typing import TYPE_CHECKING
@ -52,6 +51,7 @@ class PMobject(Mobject):
self.add_points([point], rgbas, color, opacity)
return self
@Mobject.affects_data
def set_color_by_gradient(self, *colors: ManimColor):
self.data["rgba"][:] = np.array(list(map(
color_to_rgba,
@ -59,17 +59,20 @@ class PMobject(Mobject):
)))
return self
@Mobject.affects_data
def match_colors(self, pmobject: PMobject):
self.data["rgba"][:] = resize_with_interpolation(
pmobject.data["rgba"], self.get_num_points()
)
return self
@Mobject.affects_data
def filter_out(self, condition: Callable[[np.ndarray], bool]):
for mob in self.family_members_with_points():
mob.data = mob.data[~np.apply_along_axis(condition, 1, mob.get_points())]
return self
@Mobject.affects_data
def sort_points(self, function: Callable[[Vect3], None] = lambda p: p[0]):
"""
function is any map from R^3 to R
@ -81,6 +84,7 @@ class PMobject(Mobject):
mob.data[:] = mob.data[indices]
return self
@Mobject.affects_data
def ingest_submobjects(self):
self.data = np.vstack([
sm.data for sm in self.get_family()
@ -91,6 +95,7 @@ class PMobject(Mobject):
index = alpha * (self.get_num_points() - 1)
return self.get_points()[int(index)]
@Mobject.affects_data
def pointwise_become_partial(self, pmobject: PMobject, a: float, b: float):
lower_index = int(a * pmobject.get_num_points())
upper_index = int(b * pmobject.get_num_points())

View file

@ -72,6 +72,7 @@ class Surface(Mobject):
# To be implemented in subclasses
return (u, v, 0.0)
@Mobject.affects_data
def init_points(self):
dim = self.dim
nu, nv = self.resolution
@ -130,6 +131,7 @@ class Surface(Mobject):
)
return normalize_along_axis(normals, 1)
@Mobject.affects_data
def pointwise_become_partial(
self,
smobject: "Surface",
@ -298,6 +300,7 @@ class TexturedSurface(Surface):
**kwargs
)
@Mobject.affects_data
def init_points(self):
surf = self.uv_surface
nu, nv = surf.resolution
@ -315,6 +318,7 @@ class TexturedSurface(Surface):
super().init_uniforms()
self.uniforms["num_textures"] = self.num_textures
@Mobject.affects_data
def set_opacity(self, opacity: float | Iterable[float]):
op_arr = np.array(listify(opacity))
self.data["opacity"][:, 0] = resize_with_interpolation(op_arr, len(self.data))

View file

@ -233,6 +233,7 @@ class VMobject(Mobject):
self.set_stroke(color, width, background=background)
return self
@Mobject.affects_family_data
def set_style(
self,
fill_color: ManimColor | Iterable[ManimColor] | None = None,
@ -1071,6 +1072,7 @@ class VMobject(Mobject):
return self.data["joint_product"]
self.needs_new_joint_products = False
self._data_has_changed = True
points = self.get_points()
@ -1109,6 +1111,11 @@ class VMobject(Mobject):
self.data["joint_product"][:, 3] = (vect_to_vert * vect_from_vert).sum(1)
return self.data["joint_product"]
def lock_matching_data(self, vmobject1: VMobject, vmobject2: VMobject):
for mob in [self, vmobject1, vmobject2]:
mob.get_joint_products()
super().lock_matching_data(vmobject1, vmobject2)
def triggers_refreshed_triangulation(func: Callable):
@wraps(func)
def wrapper(self, *args, refresh=True, **kwargs):
@ -1119,10 +1126,11 @@ class VMobject(Mobject):
return self
return wrapper
@triggers_refreshed_triangulation
def set_points(self, points: Vect3Array):
assert(len(points) == 0 or len(points) % 2 == 1)
super().set_points(points)
self.refresh_triangulation()
self.get_joint_products(refresh=True)
return self
@triggers_refreshed_triangulation

View file

@ -189,7 +189,6 @@ class Scene(object):
"Press `command + q` or `esc` to quit"
)
self.skip_animations = False
self.refresh_static_mobjects()
while not self.is_window_closing():
self.update_frame(1 / self.camera.fps)
@ -251,7 +250,6 @@ class Scene(object):
# Operation to run after each ipython command
def post_cell_func():
self.refresh_static_mobjects()
if not self.is_window_closing():
self.update_frame(dt=0, ignore_skipping=True)
self.save_state()
@ -562,8 +560,6 @@ class Scene(object):
self.real_animation_start_time = time.time()
self.virtual_animation_start_time = self.time
self.refresh_static_mobjects()
def post_play(self):
if not self.skip_animations:
self.file_writer.end_animation()
@ -574,10 +570,6 @@ class Scene(object):
self.num_plays += 1
def refresh_static_mobjects(self) -> None:
for mobject in self.mobjects:
mobject._data_has_changed = True
def begin_animations(self, animations: Iterable[Animation]) -> None:
for animation in animations:
animation.begin()
@ -652,7 +644,6 @@ class Scene(object):
self.emit_frame()
if stop_condition is not None and stop_condition():
break
self.refresh_static_mobjects()
self.post_play()
def hold_loop(self):
@ -712,13 +703,11 @@ class Scene(object):
if self.undo_stack:
self.redo_stack.append(self.get_state())
self.restore_state(self.undo_stack.pop())
self.refresh_static_mobjects()
def redo(self):
if self.redo_stack:
self.undo_stack.append(self.get_state())
self.restore_state(self.redo_stack.pop())
self.refresh_static_mobjects()
def checkpoint_paste(self, skip: bool = False, record: bool = False):
"""

View file

@ -51,16 +51,16 @@ class ShaderWrapper(object):
self.depth_test = depth_test
self.render_primitive = render_primitive
self.vbo = None
self.ibo = None
self.vao = None
self.init_program_code()
self.init_program()
if texture_paths is not None:
self.init_textures(texture_paths)
self.refresh_id()
self.vbo = None
self.ibo = None
self.vao = None
def init_program_code(self) -> None:
def get_code(name: str) -> str | None:
return get_shader_code_from_file(
@ -100,15 +100,16 @@ class ShaderWrapper(object):
self.render_primitive == shader_wrapper.render_primitive,
))
def __del__(self):
self.release()
def copy(self):
result = copy.copy(self)
result.ctx = self.ctx
result.vert_data = self.vert_data.copy()
result.vert_indices = self.vert_indices.copy()
if self.uniforms:
result.uniforms = {key: np.array(value) for key, value in self.uniforms.items()}
result.vao = None
result.vbo = None
result.ibo = None
return result
def is_valid(self) -> bool:
@ -219,7 +220,6 @@ class ShaderWrapper(object):
self.update_program_uniforms()
def render(self):
# TODO, generate on the fly?
assert(self.vao is not None)
self.vao.render()
@ -242,7 +242,8 @@ class ShaderWrapper(object):
self.ibo = self.ctx.buffer(self.vert_indices.astype(np.uint32))
return self.ibo
def get_vao(self, refresh: bool = True):
def generate_vao(self, refresh: bool = True):
self.release()
# Data buffer
vbo = self.get_vertex_buffer_object(refresh)
ibo = self.get_index_buffer_object(refresh)
@ -258,10 +259,7 @@ class ShaderWrapper(object):
def release(self):
for obj in (self.vbo, self.ibo, self.vao):
if obj is not None:
try:
obj.release()
except AttributeError:
pass
obj.release()
self.vbo = None
self.ibo = None
self.vao = None