From d2af6a5f4bbbe03653b221b491b408138887fb8e Mon Sep 17 00:00:00 2001 From: Grant Sanderson Date: Wed, 25 Jan 2023 16:43:47 -0800 Subject: [PATCH] Keep track of when Mobject data has changed, and used that to determine when ShaderWrapper generates new buffers --- manimlib/animation/creation.py | 14 ------ manimlib/mobject/mobject.py | 49 +++++++++++++++---- manimlib/mobject/types/dot_cloud.py | 3 ++ manimlib/mobject/types/image_mobject.py | 1 + manimlib/mobject/types/point_cloud_mobject.py | 7 ++- manimlib/mobject/types/surface.py | 4 ++ manimlib/mobject/types/vectorized_mobject.py | 10 +++- manimlib/scene/scene.py | 11 ----- manimlib/shader_wrapper.py | 24 +++++---- 9 files changed, 74 insertions(+), 49 deletions(-) diff --git a/manimlib/animation/creation.py b/manimlib/animation/creation.py index 8271eb2c..cec23ab0 100644 --- a/manimlib/animation/creation.py +++ b/manimlib/animation/creation.py @@ -30,15 +30,6 @@ class ShowPartial(Animation, ABC): self.should_match_start = should_match_start super().__init__(mobject, **kwargs) - def begin(self) -> None: - super().begin() - if not self.should_match_start: - self.mobject.lock_matching_data(self.mobject, self.starting_mobject) - - def finish(self) -> None: - super().finish() - self.mobject.unlock_data() - def interpolate_submobject( self, submob: VMobject, @@ -114,11 +105,9 @@ class DrawBorderThenFill(Animation): self.outline = self.get_outline() super().begin() self.mobject.match_style(self.outline) - self.mobject.lock_matching_data(self.mobject, self.outline) def finish(self) -> None: super().finish() - self.mobject.unlock_data() self.mobject.refresh_joint_products() def get_outline(self) -> VMobject: @@ -146,9 +135,6 @@ class DrawBorderThenFill(Animation): if index == 1 and self.sm_to_index[hash(submob)] == 0: # First time crossing over submob.set_data(outline.data) - submob.unlock_data() - if not self.mobject.has_updaters: - submob.lock_matching_data(submob, start) submob.needs_new_triangulation = False self.sm_to_index[hash(submob)] = 1 diff --git a/manimlib/mobject/mobject.py b/manimlib/mobject/mobject.py index 93162bec..5c1b1576 100644 --- a/manimlib/mobject/mobject.py +++ b/manimlib/mobject/mobject.py @@ -143,11 +143,6 @@ class Mobject(object): # Typically implemented in subclass, unlpess purposefully left blank pass - def set_data(self, data: np.ndarray): - assert(data.dtype == self.data.dtype) - self.data = data - return self - def set_uniforms(self, uniforms: dict): for key, value in uniforms.items(): if isinstance(value, np.ndarray): @@ -160,8 +155,36 @@ class Mobject(object): # Borrowed from https://github.com/ManimCommunity/manim/ return _AnimationBuilder(self) - # Only these methods should directly affect points + def note_changed_data(self, recurse_up: bool = True): + self._data_has_changed = True + if recurse_up: + for mob in self.parents: + mob.note_changed_data() + def affects_data(func: Callable): + @wraps(func) + def wrapper(self, *args, **kwargs): + func(self, *args, **kwargs) + self.note_changed_data() + return wrapper + + def affects_family_data(func: Callable): + @wraps(func) + def wrapper(self, *args, **kwargs): + func(self, *args, **kwargs) + for mob in self.family_members_with_points(): + mob.note_changed_data() + return self + return wrapper + + # Only these methods should directly affect points + @affects_data + def set_data(self, data: np.ndarray): + assert(data.dtype == self.data.dtype) + self.data = data + return self + + @affects_data def resize_points( self, new_length: int, @@ -177,11 +200,13 @@ class Mobject(object): self.refresh_bounding_box() return self + @affects_data def set_points(self, points: Vect3Array): self.resize_points(len(points), resize_func=resize_preserving_order) self.data["point"][:] = points return self + @affects_data def append_points(self, new_points: Vect3Array): n = self.get_num_points() self.resize_points(n + len(new_points)) @@ -192,11 +217,13 @@ class Mobject(object): self.refresh_bounding_box() return self + @affects_family_data def reverse_points(self): for mob in self.get_family(): mob.data = mob.data[::-1] return self + @affects_family_data def apply_points_function( self, func: Callable[[np.ndarray], np.ndarray], @@ -330,6 +357,7 @@ class Mobject(object): def split(self) -> list[Mobject]: return self.submobjects + @affects_data def assemble_family(self): sub_families = (sm.get_family() for sm in self.submobjects) self.family = [self, *it.chain(*sub_families)] @@ -593,6 +621,7 @@ class Mobject(object): # won't have changed, just directly match. result.non_time_updaters = list(self.non_time_updaters) result.time_based_updaters = list(self.time_based_updaters) + result._data_has_changed = True family = self.get_family() for attr, value in list(self.__dict__.items()): @@ -1216,6 +1245,7 @@ class Mobject(object): # Color functions + @affects_family_data def set_rgba_array( self, rgba_array: npt.ArrayLike, @@ -1254,6 +1284,7 @@ class Mobject(object): mob.set_rgba_array(rgba_array) return self + @affects_family_data def set_rgba_array_by_color( self, color: ManimColor | Iterable[ManimColor] | None = None, @@ -1681,6 +1712,7 @@ class Mobject(object): # Interpolate + @affects_data def interpolate( self, mobject1: Mobject, @@ -1893,11 +1925,10 @@ class Mobject(object): return self.shader_indices def render(self, ctx: Context, camera_uniforms: dict): - if self._data_has_changed or self.is_changing(): + if self._data_has_changed: self.shader_wrappers = self.get_shader_wrapper_list(ctx) for shader_wrapper in self.shader_wrappers: - shader_wrapper.release() - shader_wrapper.get_vao() + shader_wrapper.generate_vao() self._data_has_changed = False for shader_wrapper in self.shader_wrappers: shader_wrapper.uniforms.update(self.get_uniforms()) diff --git a/manimlib/mobject/types/dot_cloud.py b/manimlib/mobject/types/dot_cloud.py index 613c6e0b..83aa9b07 100644 --- a/manimlib/mobject/types/dot_cloud.py +++ b/manimlib/mobject/types/dot_cloud.py @@ -5,6 +5,7 @@ import numpy as np from manimlib.constants import GREY_C, YELLOW from manimlib.constants import ORIGIN, NULL_POINTS +from manimlib.mobject.mobject import Mobject from manimlib.mobject.types.point_cloud_mobject import PMobject from manimlib.utils.iterables import resize_with_interpolation @@ -94,6 +95,7 @@ class DotCloud(PMobject): self.center() return self + @Mobject.affects_data def set_radii(self, radii: npt.ArrayLike): n_points = self.get_num_points() radii = np.array(radii).reshape((len(radii), 1)) @@ -104,6 +106,7 @@ class DotCloud(PMobject): def get_radii(self) -> np.ndarray: return self.data["radius"] + @Mobject.affects_data def set_radius(self, radius: float): data = self.data if self.get_num_points() > 0 else self._data_defaults data["radius"][:] = radius diff --git a/manimlib/mobject/types/image_mobject.py b/manimlib/mobject/types/image_mobject.py index 8de5b10e..80efe5aa 100644 --- a/manimlib/mobject/types/image_mobject.py +++ b/manimlib/mobject/types/image_mobject.py @@ -47,6 +47,7 @@ class ImageMobject(Mobject): self.set_width(2 * size[0] / size[1], stretch=True) self.set_height(self.height) + @Mobject.affects_data def set_opacity(self, opacity: float, recurse: bool = True): self.data["opacity"][:, 0] = resize_with_interpolation( np.array(listify(opacity)), diff --git a/manimlib/mobject/types/point_cloud_mobject.py b/manimlib/mobject/types/point_cloud_mobject.py index ee1b384b..f6180b3f 100644 --- a/manimlib/mobject/types/point_cloud_mobject.py +++ b/manimlib/mobject/types/point_cloud_mobject.py @@ -5,7 +5,6 @@ import numpy as np from manimlib.mobject.mobject import Mobject from manimlib.utils.color import color_gradient from manimlib.utils.color import color_to_rgba -from manimlib.utils.iterables import resize_array from manimlib.utils.iterables import resize_with_interpolation from typing import TYPE_CHECKING @@ -52,6 +51,7 @@ class PMobject(Mobject): self.add_points([point], rgbas, color, opacity) return self + @Mobject.affects_data def set_color_by_gradient(self, *colors: ManimColor): self.data["rgba"][:] = np.array(list(map( color_to_rgba, @@ -59,17 +59,20 @@ class PMobject(Mobject): ))) return self + @Mobject.affects_data def match_colors(self, pmobject: PMobject): self.data["rgba"][:] = resize_with_interpolation( pmobject.data["rgba"], self.get_num_points() ) return self + @Mobject.affects_data def filter_out(self, condition: Callable[[np.ndarray], bool]): for mob in self.family_members_with_points(): mob.data = mob.data[~np.apply_along_axis(condition, 1, mob.get_points())] return self + @Mobject.affects_data def sort_points(self, function: Callable[[Vect3], None] = lambda p: p[0]): """ function is any map from R^3 to R @@ -81,6 +84,7 @@ class PMobject(Mobject): mob.data[:] = mob.data[indices] return self + @Mobject.affects_data def ingest_submobjects(self): self.data = np.vstack([ sm.data for sm in self.get_family() @@ -91,6 +95,7 @@ class PMobject(Mobject): index = alpha * (self.get_num_points() - 1) return self.get_points()[int(index)] + @Mobject.affects_data def pointwise_become_partial(self, pmobject: PMobject, a: float, b: float): lower_index = int(a * pmobject.get_num_points()) upper_index = int(b * pmobject.get_num_points()) diff --git a/manimlib/mobject/types/surface.py b/manimlib/mobject/types/surface.py index 8b7e2ee3..81855bac 100644 --- a/manimlib/mobject/types/surface.py +++ b/manimlib/mobject/types/surface.py @@ -72,6 +72,7 @@ class Surface(Mobject): # To be implemented in subclasses return (u, v, 0.0) + @Mobject.affects_data def init_points(self): dim = self.dim nu, nv = self.resolution @@ -130,6 +131,7 @@ class Surface(Mobject): ) return normalize_along_axis(normals, 1) + @Mobject.affects_data def pointwise_become_partial( self, smobject: "Surface", @@ -298,6 +300,7 @@ class TexturedSurface(Surface): **kwargs ) + @Mobject.affects_data def init_points(self): surf = self.uv_surface nu, nv = surf.resolution @@ -315,6 +318,7 @@ class TexturedSurface(Surface): super().init_uniforms() self.uniforms["num_textures"] = self.num_textures + @Mobject.affects_data def set_opacity(self, opacity: float | Iterable[float]): op_arr = np.array(listify(opacity)) self.data["opacity"][:, 0] = resize_with_interpolation(op_arr, len(self.data)) diff --git a/manimlib/mobject/types/vectorized_mobject.py b/manimlib/mobject/types/vectorized_mobject.py index c04577ed..b6cecca9 100644 --- a/manimlib/mobject/types/vectorized_mobject.py +++ b/manimlib/mobject/types/vectorized_mobject.py @@ -233,6 +233,7 @@ class VMobject(Mobject): self.set_stroke(color, width, background=background) return self + @Mobject.affects_family_data def set_style( self, fill_color: ManimColor | Iterable[ManimColor] | None = None, @@ -1071,6 +1072,7 @@ class VMobject(Mobject): return self.data["joint_product"] self.needs_new_joint_products = False + self._data_has_changed = True points = self.get_points() @@ -1109,6 +1111,11 @@ class VMobject(Mobject): self.data["joint_product"][:, 3] = (vect_to_vert * vect_from_vert).sum(1) return self.data["joint_product"] + def lock_matching_data(self, vmobject1: VMobject, vmobject2: VMobject): + for mob in [self, vmobject1, vmobject2]: + mob.get_joint_products() + super().lock_matching_data(vmobject1, vmobject2) + def triggers_refreshed_triangulation(func: Callable): @wraps(func) def wrapper(self, *args, refresh=True, **kwargs): @@ -1119,10 +1126,11 @@ class VMobject(Mobject): return self return wrapper - @triggers_refreshed_triangulation def set_points(self, points: Vect3Array): assert(len(points) == 0 or len(points) % 2 == 1) super().set_points(points) + self.refresh_triangulation() + self.get_joint_products(refresh=True) return self @triggers_refreshed_triangulation diff --git a/manimlib/scene/scene.py b/manimlib/scene/scene.py index 53d33d8d..04b4951c 100644 --- a/manimlib/scene/scene.py +++ b/manimlib/scene/scene.py @@ -189,7 +189,6 @@ class Scene(object): "Press `command + q` or `esc` to quit" ) self.skip_animations = False - self.refresh_static_mobjects() while not self.is_window_closing(): self.update_frame(1 / self.camera.fps) @@ -251,7 +250,6 @@ class Scene(object): # Operation to run after each ipython command def post_cell_func(): - self.refresh_static_mobjects() if not self.is_window_closing(): self.update_frame(dt=0, ignore_skipping=True) self.save_state() @@ -562,8 +560,6 @@ class Scene(object): self.real_animation_start_time = time.time() self.virtual_animation_start_time = self.time - self.refresh_static_mobjects() - def post_play(self): if not self.skip_animations: self.file_writer.end_animation() @@ -574,10 +570,6 @@ class Scene(object): self.num_plays += 1 - def refresh_static_mobjects(self) -> None: - for mobject in self.mobjects: - mobject._data_has_changed = True - def begin_animations(self, animations: Iterable[Animation]) -> None: for animation in animations: animation.begin() @@ -652,7 +644,6 @@ class Scene(object): self.emit_frame() if stop_condition is not None and stop_condition(): break - self.refresh_static_mobjects() self.post_play() def hold_loop(self): @@ -712,13 +703,11 @@ class Scene(object): if self.undo_stack: self.redo_stack.append(self.get_state()) self.restore_state(self.undo_stack.pop()) - self.refresh_static_mobjects() def redo(self): if self.redo_stack: self.undo_stack.append(self.get_state()) self.restore_state(self.redo_stack.pop()) - self.refresh_static_mobjects() def checkpoint_paste(self, skip: bool = False, record: bool = False): """ diff --git a/manimlib/shader_wrapper.py b/manimlib/shader_wrapper.py index b7e311b1..c273e765 100644 --- a/manimlib/shader_wrapper.py +++ b/manimlib/shader_wrapper.py @@ -51,16 +51,16 @@ class ShaderWrapper(object): self.depth_test = depth_test self.render_primitive = render_primitive - self.vbo = None - self.ibo = None - self.vao = None - self.init_program_code() self.init_program() if texture_paths is not None: self.init_textures(texture_paths) self.refresh_id() + self.vbo = None + self.ibo = None + self.vao = None + def init_program_code(self) -> None: def get_code(name: str) -> str | None: return get_shader_code_from_file( @@ -100,15 +100,16 @@ class ShaderWrapper(object): self.render_primitive == shader_wrapper.render_primitive, )) - def __del__(self): - self.release() - def copy(self): result = copy.copy(self) + result.ctx = self.ctx result.vert_data = self.vert_data.copy() result.vert_indices = self.vert_indices.copy() if self.uniforms: result.uniforms = {key: np.array(value) for key, value in self.uniforms.items()} + result.vao = None + result.vbo = None + result.ibo = None return result def is_valid(self) -> bool: @@ -219,7 +220,6 @@ class ShaderWrapper(object): self.update_program_uniforms() def render(self): - # TODO, generate on the fly? assert(self.vao is not None) self.vao.render() @@ -242,7 +242,8 @@ class ShaderWrapper(object): self.ibo = self.ctx.buffer(self.vert_indices.astype(np.uint32)) return self.ibo - def get_vao(self, refresh: bool = True): + def generate_vao(self, refresh: bool = True): + self.release() # Data buffer vbo = self.get_vertex_buffer_object(refresh) ibo = self.get_index_buffer_object(refresh) @@ -258,10 +259,7 @@ class ShaderWrapper(object): def release(self): for obj in (self.vbo, self.ibo, self.vao): if obj is not None: - try: - obj.release() - except AttributeError: - pass + obj.release() self.vbo = None self.ibo = None self.vao = None