mirror of
https://github.com/3b1b/manim.git
synced 2025-04-13 09:47:07 +00:00
Keep track of when Mobject data has changed, and used that to determine when ShaderWrapper generates new buffers
This commit is contained in:
parent
4dfabc1c28
commit
d2af6a5f4b
9 changed files with 74 additions and 49 deletions
|
@ -30,15 +30,6 @@ class ShowPartial(Animation, ABC):
|
|||
self.should_match_start = should_match_start
|
||||
super().__init__(mobject, **kwargs)
|
||||
|
||||
def begin(self) -> None:
|
||||
super().begin()
|
||||
if not self.should_match_start:
|
||||
self.mobject.lock_matching_data(self.mobject, self.starting_mobject)
|
||||
|
||||
def finish(self) -> None:
|
||||
super().finish()
|
||||
self.mobject.unlock_data()
|
||||
|
||||
def interpolate_submobject(
|
||||
self,
|
||||
submob: VMobject,
|
||||
|
@ -114,11 +105,9 @@ class DrawBorderThenFill(Animation):
|
|||
self.outline = self.get_outline()
|
||||
super().begin()
|
||||
self.mobject.match_style(self.outline)
|
||||
self.mobject.lock_matching_data(self.mobject, self.outline)
|
||||
|
||||
def finish(self) -> None:
|
||||
super().finish()
|
||||
self.mobject.unlock_data()
|
||||
self.mobject.refresh_joint_products()
|
||||
|
||||
def get_outline(self) -> VMobject:
|
||||
|
@ -146,9 +135,6 @@ class DrawBorderThenFill(Animation):
|
|||
if index == 1 and self.sm_to_index[hash(submob)] == 0:
|
||||
# First time crossing over
|
||||
submob.set_data(outline.data)
|
||||
submob.unlock_data()
|
||||
if not self.mobject.has_updaters:
|
||||
submob.lock_matching_data(submob, start)
|
||||
submob.needs_new_triangulation = False
|
||||
self.sm_to_index[hash(submob)] = 1
|
||||
|
||||
|
|
|
@ -143,11 +143,6 @@ class Mobject(object):
|
|||
# Typically implemented in subclass, unlpess purposefully left blank
|
||||
pass
|
||||
|
||||
def set_data(self, data: np.ndarray):
|
||||
assert(data.dtype == self.data.dtype)
|
||||
self.data = data
|
||||
return self
|
||||
|
||||
def set_uniforms(self, uniforms: dict):
|
||||
for key, value in uniforms.items():
|
||||
if isinstance(value, np.ndarray):
|
||||
|
@ -160,8 +155,36 @@ class Mobject(object):
|
|||
# Borrowed from https://github.com/ManimCommunity/manim/
|
||||
return _AnimationBuilder(self)
|
||||
|
||||
# Only these methods should directly affect points
|
||||
def note_changed_data(self, recurse_up: bool = True):
|
||||
self._data_has_changed = True
|
||||
if recurse_up:
|
||||
for mob in self.parents:
|
||||
mob.note_changed_data()
|
||||
|
||||
def affects_data(func: Callable):
|
||||
@wraps(func)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
func(self, *args, **kwargs)
|
||||
self.note_changed_data()
|
||||
return wrapper
|
||||
|
||||
def affects_family_data(func: Callable):
|
||||
@wraps(func)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
func(self, *args, **kwargs)
|
||||
for mob in self.family_members_with_points():
|
||||
mob.note_changed_data()
|
||||
return self
|
||||
return wrapper
|
||||
|
||||
# Only these methods should directly affect points
|
||||
@affects_data
|
||||
def set_data(self, data: np.ndarray):
|
||||
assert(data.dtype == self.data.dtype)
|
||||
self.data = data
|
||||
return self
|
||||
|
||||
@affects_data
|
||||
def resize_points(
|
||||
self,
|
||||
new_length: int,
|
||||
|
@ -177,11 +200,13 @@ class Mobject(object):
|
|||
self.refresh_bounding_box()
|
||||
return self
|
||||
|
||||
@affects_data
|
||||
def set_points(self, points: Vect3Array):
|
||||
self.resize_points(len(points), resize_func=resize_preserving_order)
|
||||
self.data["point"][:] = points
|
||||
return self
|
||||
|
||||
@affects_data
|
||||
def append_points(self, new_points: Vect3Array):
|
||||
n = self.get_num_points()
|
||||
self.resize_points(n + len(new_points))
|
||||
|
@ -192,11 +217,13 @@ class Mobject(object):
|
|||
self.refresh_bounding_box()
|
||||
return self
|
||||
|
||||
@affects_family_data
|
||||
def reverse_points(self):
|
||||
for mob in self.get_family():
|
||||
mob.data = mob.data[::-1]
|
||||
return self
|
||||
|
||||
@affects_family_data
|
||||
def apply_points_function(
|
||||
self,
|
||||
func: Callable[[np.ndarray], np.ndarray],
|
||||
|
@ -330,6 +357,7 @@ class Mobject(object):
|
|||
def split(self) -> list[Mobject]:
|
||||
return self.submobjects
|
||||
|
||||
@affects_data
|
||||
def assemble_family(self):
|
||||
sub_families = (sm.get_family() for sm in self.submobjects)
|
||||
self.family = [self, *it.chain(*sub_families)]
|
||||
|
@ -593,6 +621,7 @@ class Mobject(object):
|
|||
# won't have changed, just directly match.
|
||||
result.non_time_updaters = list(self.non_time_updaters)
|
||||
result.time_based_updaters = list(self.time_based_updaters)
|
||||
result._data_has_changed = True
|
||||
|
||||
family = self.get_family()
|
||||
for attr, value in list(self.__dict__.items()):
|
||||
|
@ -1216,6 +1245,7 @@ class Mobject(object):
|
|||
|
||||
# Color functions
|
||||
|
||||
@affects_family_data
|
||||
def set_rgba_array(
|
||||
self,
|
||||
rgba_array: npt.ArrayLike,
|
||||
|
@ -1254,6 +1284,7 @@ class Mobject(object):
|
|||
mob.set_rgba_array(rgba_array)
|
||||
return self
|
||||
|
||||
@affects_family_data
|
||||
def set_rgba_array_by_color(
|
||||
self,
|
||||
color: ManimColor | Iterable[ManimColor] | None = None,
|
||||
|
@ -1681,6 +1712,7 @@ class Mobject(object):
|
|||
|
||||
# Interpolate
|
||||
|
||||
@affects_data
|
||||
def interpolate(
|
||||
self,
|
||||
mobject1: Mobject,
|
||||
|
@ -1893,11 +1925,10 @@ class Mobject(object):
|
|||
return self.shader_indices
|
||||
|
||||
def render(self, ctx: Context, camera_uniforms: dict):
|
||||
if self._data_has_changed or self.is_changing():
|
||||
if self._data_has_changed:
|
||||
self.shader_wrappers = self.get_shader_wrapper_list(ctx)
|
||||
for shader_wrapper in self.shader_wrappers:
|
||||
shader_wrapper.release()
|
||||
shader_wrapper.get_vao()
|
||||
shader_wrapper.generate_vao()
|
||||
self._data_has_changed = False
|
||||
for shader_wrapper in self.shader_wrappers:
|
||||
shader_wrapper.uniforms.update(self.get_uniforms())
|
||||
|
|
|
@ -5,6 +5,7 @@ import numpy as np
|
|||
|
||||
from manimlib.constants import GREY_C, YELLOW
|
||||
from manimlib.constants import ORIGIN, NULL_POINTS
|
||||
from manimlib.mobject.mobject import Mobject
|
||||
from manimlib.mobject.types.point_cloud_mobject import PMobject
|
||||
from manimlib.utils.iterables import resize_with_interpolation
|
||||
|
||||
|
@ -94,6 +95,7 @@ class DotCloud(PMobject):
|
|||
self.center()
|
||||
return self
|
||||
|
||||
@Mobject.affects_data
|
||||
def set_radii(self, radii: npt.ArrayLike):
|
||||
n_points = self.get_num_points()
|
||||
radii = np.array(radii).reshape((len(radii), 1))
|
||||
|
@ -104,6 +106,7 @@ class DotCloud(PMobject):
|
|||
def get_radii(self) -> np.ndarray:
|
||||
return self.data["radius"]
|
||||
|
||||
@Mobject.affects_data
|
||||
def set_radius(self, radius: float):
|
||||
data = self.data if self.get_num_points() > 0 else self._data_defaults
|
||||
data["radius"][:] = radius
|
||||
|
|
|
@ -47,6 +47,7 @@ class ImageMobject(Mobject):
|
|||
self.set_width(2 * size[0] / size[1], stretch=True)
|
||||
self.set_height(self.height)
|
||||
|
||||
@Mobject.affects_data
|
||||
def set_opacity(self, opacity: float, recurse: bool = True):
|
||||
self.data["opacity"][:, 0] = resize_with_interpolation(
|
||||
np.array(listify(opacity)),
|
||||
|
|
|
@ -5,7 +5,6 @@ import numpy as np
|
|||
from manimlib.mobject.mobject import Mobject
|
||||
from manimlib.utils.color import color_gradient
|
||||
from manimlib.utils.color import color_to_rgba
|
||||
from manimlib.utils.iterables import resize_array
|
||||
from manimlib.utils.iterables import resize_with_interpolation
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
@ -52,6 +51,7 @@ class PMobject(Mobject):
|
|||
self.add_points([point], rgbas, color, opacity)
|
||||
return self
|
||||
|
||||
@Mobject.affects_data
|
||||
def set_color_by_gradient(self, *colors: ManimColor):
|
||||
self.data["rgba"][:] = np.array(list(map(
|
||||
color_to_rgba,
|
||||
|
@ -59,17 +59,20 @@ class PMobject(Mobject):
|
|||
)))
|
||||
return self
|
||||
|
||||
@Mobject.affects_data
|
||||
def match_colors(self, pmobject: PMobject):
|
||||
self.data["rgba"][:] = resize_with_interpolation(
|
||||
pmobject.data["rgba"], self.get_num_points()
|
||||
)
|
||||
return self
|
||||
|
||||
@Mobject.affects_data
|
||||
def filter_out(self, condition: Callable[[np.ndarray], bool]):
|
||||
for mob in self.family_members_with_points():
|
||||
mob.data = mob.data[~np.apply_along_axis(condition, 1, mob.get_points())]
|
||||
return self
|
||||
|
||||
@Mobject.affects_data
|
||||
def sort_points(self, function: Callable[[Vect3], None] = lambda p: p[0]):
|
||||
"""
|
||||
function is any map from R^3 to R
|
||||
|
@ -81,6 +84,7 @@ class PMobject(Mobject):
|
|||
mob.data[:] = mob.data[indices]
|
||||
return self
|
||||
|
||||
@Mobject.affects_data
|
||||
def ingest_submobjects(self):
|
||||
self.data = np.vstack([
|
||||
sm.data for sm in self.get_family()
|
||||
|
@ -91,6 +95,7 @@ class PMobject(Mobject):
|
|||
index = alpha * (self.get_num_points() - 1)
|
||||
return self.get_points()[int(index)]
|
||||
|
||||
@Mobject.affects_data
|
||||
def pointwise_become_partial(self, pmobject: PMobject, a: float, b: float):
|
||||
lower_index = int(a * pmobject.get_num_points())
|
||||
upper_index = int(b * pmobject.get_num_points())
|
||||
|
|
|
@ -72,6 +72,7 @@ class Surface(Mobject):
|
|||
# To be implemented in subclasses
|
||||
return (u, v, 0.0)
|
||||
|
||||
@Mobject.affects_data
|
||||
def init_points(self):
|
||||
dim = self.dim
|
||||
nu, nv = self.resolution
|
||||
|
@ -130,6 +131,7 @@ class Surface(Mobject):
|
|||
)
|
||||
return normalize_along_axis(normals, 1)
|
||||
|
||||
@Mobject.affects_data
|
||||
def pointwise_become_partial(
|
||||
self,
|
||||
smobject: "Surface",
|
||||
|
@ -298,6 +300,7 @@ class TexturedSurface(Surface):
|
|||
**kwargs
|
||||
)
|
||||
|
||||
@Mobject.affects_data
|
||||
def init_points(self):
|
||||
surf = self.uv_surface
|
||||
nu, nv = surf.resolution
|
||||
|
@ -315,6 +318,7 @@ class TexturedSurface(Surface):
|
|||
super().init_uniforms()
|
||||
self.uniforms["num_textures"] = self.num_textures
|
||||
|
||||
@Mobject.affects_data
|
||||
def set_opacity(self, opacity: float | Iterable[float]):
|
||||
op_arr = np.array(listify(opacity))
|
||||
self.data["opacity"][:, 0] = resize_with_interpolation(op_arr, len(self.data))
|
||||
|
|
|
@ -233,6 +233,7 @@ class VMobject(Mobject):
|
|||
self.set_stroke(color, width, background=background)
|
||||
return self
|
||||
|
||||
@Mobject.affects_family_data
|
||||
def set_style(
|
||||
self,
|
||||
fill_color: ManimColor | Iterable[ManimColor] | None = None,
|
||||
|
@ -1071,6 +1072,7 @@ class VMobject(Mobject):
|
|||
return self.data["joint_product"]
|
||||
|
||||
self.needs_new_joint_products = False
|
||||
self._data_has_changed = True
|
||||
|
||||
points = self.get_points()
|
||||
|
||||
|
@ -1109,6 +1111,11 @@ class VMobject(Mobject):
|
|||
self.data["joint_product"][:, 3] = (vect_to_vert * vect_from_vert).sum(1)
|
||||
return self.data["joint_product"]
|
||||
|
||||
def lock_matching_data(self, vmobject1: VMobject, vmobject2: VMobject):
|
||||
for mob in [self, vmobject1, vmobject2]:
|
||||
mob.get_joint_products()
|
||||
super().lock_matching_data(vmobject1, vmobject2)
|
||||
|
||||
def triggers_refreshed_triangulation(func: Callable):
|
||||
@wraps(func)
|
||||
def wrapper(self, *args, refresh=True, **kwargs):
|
||||
|
@ -1119,10 +1126,11 @@ class VMobject(Mobject):
|
|||
return self
|
||||
return wrapper
|
||||
|
||||
@triggers_refreshed_triangulation
|
||||
def set_points(self, points: Vect3Array):
|
||||
assert(len(points) == 0 or len(points) % 2 == 1)
|
||||
super().set_points(points)
|
||||
self.refresh_triangulation()
|
||||
self.get_joint_products(refresh=True)
|
||||
return self
|
||||
|
||||
@triggers_refreshed_triangulation
|
||||
|
|
|
@ -189,7 +189,6 @@ class Scene(object):
|
|||
"Press `command + q` or `esc` to quit"
|
||||
)
|
||||
self.skip_animations = False
|
||||
self.refresh_static_mobjects()
|
||||
while not self.is_window_closing():
|
||||
self.update_frame(1 / self.camera.fps)
|
||||
|
||||
|
@ -251,7 +250,6 @@ class Scene(object):
|
|||
|
||||
# Operation to run after each ipython command
|
||||
def post_cell_func():
|
||||
self.refresh_static_mobjects()
|
||||
if not self.is_window_closing():
|
||||
self.update_frame(dt=0, ignore_skipping=True)
|
||||
self.save_state()
|
||||
|
@ -562,8 +560,6 @@ class Scene(object):
|
|||
self.real_animation_start_time = time.time()
|
||||
self.virtual_animation_start_time = self.time
|
||||
|
||||
self.refresh_static_mobjects()
|
||||
|
||||
def post_play(self):
|
||||
if not self.skip_animations:
|
||||
self.file_writer.end_animation()
|
||||
|
@ -574,10 +570,6 @@ class Scene(object):
|
|||
|
||||
self.num_plays += 1
|
||||
|
||||
def refresh_static_mobjects(self) -> None:
|
||||
for mobject in self.mobjects:
|
||||
mobject._data_has_changed = True
|
||||
|
||||
def begin_animations(self, animations: Iterable[Animation]) -> None:
|
||||
for animation in animations:
|
||||
animation.begin()
|
||||
|
@ -652,7 +644,6 @@ class Scene(object):
|
|||
self.emit_frame()
|
||||
if stop_condition is not None and stop_condition():
|
||||
break
|
||||
self.refresh_static_mobjects()
|
||||
self.post_play()
|
||||
|
||||
def hold_loop(self):
|
||||
|
@ -712,13 +703,11 @@ class Scene(object):
|
|||
if self.undo_stack:
|
||||
self.redo_stack.append(self.get_state())
|
||||
self.restore_state(self.undo_stack.pop())
|
||||
self.refresh_static_mobjects()
|
||||
|
||||
def redo(self):
|
||||
if self.redo_stack:
|
||||
self.undo_stack.append(self.get_state())
|
||||
self.restore_state(self.redo_stack.pop())
|
||||
self.refresh_static_mobjects()
|
||||
|
||||
def checkpoint_paste(self, skip: bool = False, record: bool = False):
|
||||
"""
|
||||
|
|
|
@ -51,16 +51,16 @@ class ShaderWrapper(object):
|
|||
self.depth_test = depth_test
|
||||
self.render_primitive = render_primitive
|
||||
|
||||
self.vbo = None
|
||||
self.ibo = None
|
||||
self.vao = None
|
||||
|
||||
self.init_program_code()
|
||||
self.init_program()
|
||||
if texture_paths is not None:
|
||||
self.init_textures(texture_paths)
|
||||
self.refresh_id()
|
||||
|
||||
self.vbo = None
|
||||
self.ibo = None
|
||||
self.vao = None
|
||||
|
||||
def init_program_code(self) -> None:
|
||||
def get_code(name: str) -> str | None:
|
||||
return get_shader_code_from_file(
|
||||
|
@ -100,15 +100,16 @@ class ShaderWrapper(object):
|
|||
self.render_primitive == shader_wrapper.render_primitive,
|
||||
))
|
||||
|
||||
def __del__(self):
|
||||
self.release()
|
||||
|
||||
def copy(self):
|
||||
result = copy.copy(self)
|
||||
result.ctx = self.ctx
|
||||
result.vert_data = self.vert_data.copy()
|
||||
result.vert_indices = self.vert_indices.copy()
|
||||
if self.uniforms:
|
||||
result.uniforms = {key: np.array(value) for key, value in self.uniforms.items()}
|
||||
result.vao = None
|
||||
result.vbo = None
|
||||
result.ibo = None
|
||||
return result
|
||||
|
||||
def is_valid(self) -> bool:
|
||||
|
@ -219,7 +220,6 @@ class ShaderWrapper(object):
|
|||
self.update_program_uniforms()
|
||||
|
||||
def render(self):
|
||||
# TODO, generate on the fly?
|
||||
assert(self.vao is not None)
|
||||
self.vao.render()
|
||||
|
||||
|
@ -242,7 +242,8 @@ class ShaderWrapper(object):
|
|||
self.ibo = self.ctx.buffer(self.vert_indices.astype(np.uint32))
|
||||
return self.ibo
|
||||
|
||||
def get_vao(self, refresh: bool = True):
|
||||
def generate_vao(self, refresh: bool = True):
|
||||
self.release()
|
||||
# Data buffer
|
||||
vbo = self.get_vertex_buffer_object(refresh)
|
||||
ibo = self.get_index_buffer_object(refresh)
|
||||
|
@ -258,10 +259,7 @@ class ShaderWrapper(object):
|
|||
def release(self):
|
||||
for obj in (self.vbo, self.ibo, self.vao):
|
||||
if obj is not None:
|
||||
try:
|
||||
obj.release()
|
||||
except AttributeError:
|
||||
pass
|
||||
obj.release()
|
||||
self.vbo = None
|
||||
self.ibo = None
|
||||
self.vao = None
|
||||
|
|
Loading…
Add table
Reference in a new issue