diff --git a/manimlib/scene/scene.py b/manimlib/scene/scene.py index 5c29d3e6..6ad771f4 100644 --- a/manimlib/scene/scene.py +++ b/manimlib/scene/scene.py @@ -7,6 +7,7 @@ import platform import pyperclip import random import time +from functools import wraps from IPython.terminal import pt_inputhooks from IPython.terminal.embed import InteractiveShellEmbed @@ -37,6 +38,7 @@ from manimlib.mobject.types.vectorized_mobject import VMobject from manimlib.scene.scene_file_writer import SceneFileWriter from manimlib.utils.family_ops import extract_mobject_family_members from manimlib.utils.family_ops import recursive_mobject_remove +from manimlib.utils.iterables import batch_by_property from typing import TYPE_CHECKING @@ -110,6 +112,7 @@ class Scene(object): self.camera: Camera = Camera(**self.camera_config) self.file_writer = SceneFileWriter(self, **self.file_writer_config) self.mobjects: list[Mobject] = [self.camera.frame] + self.render_groups: list[Mobject] = [] self.id_to_mobject_map: dict[int, Mobject] = dict() self.num_plays: int = 0 self.time: float = 0 @@ -289,7 +292,7 @@ class Scene(object): def get_image(self) -> Image: if self.window is not None: self.camera.use_window_fbo(False) - self.camera.capture(*self.mobjects) + self.camera.capture(*self.render_groups) image = self.camera.get_image() if self.window is not None: self.camera.use_window_fbo(True) @@ -310,7 +313,7 @@ class Scene(object): if self.window: self.window.clear() - self.camera.capture(*self.mobjects) + self.camera.capture(*self.render_groups) if self.window: self.window.swap_buffers() @@ -369,6 +372,34 @@ class Scene(object): def get_mobject_family_members(self) -> list[Mobject]: return extract_mobject_family_members(self.mobjects) + def assemble_render_groups(self): + """ + Rendering is more efficient when VMobjects are grouped + together, so this function creates VGroups of all + clusters of adjacent VMobjects in the scene's mobject + list. + """ + for group in self.render_groups: + group.clear() + self.render_groups = [] + batches = batch_by_property( + self.mobjects, + lambda m: str(m.get_uniforms()) + str(m.apply_depth_test) + ) + self.render_groups = [ + batch[0].get_group_class()(*batch) + for batch, key in batches + ] + + def affects_mobject_list(func: Callable): + @wraps(func) + def wrapper(self, *args, **kwargs): + func(self, *args, **kwargs) + self.assemble_render_groups() + return self + return wrapper + + @affects_mobject_list def add(self, *new_mobjects: Mobject): """ Mobjects will be displayed, from background to @@ -395,6 +426,7 @@ class Scene(object): )) return self + @affects_mobject_list def replace(self, mobject: Mobject, *replacements: Mobject): if mobject in self.mobjects: index = self.mobjects.index(mobject) @@ -405,6 +437,7 @@ class Scene(object): ] return self + @affects_mobject_list def remove(self, *mobjects_to_remove: Mobject): """ Removes anything in mobjects from scenes mobject list, but in the event that one @@ -422,11 +455,13 @@ class Scene(object): self.add(*mobjects) return self + @affects_mobject_list def bring_to_back(self, *mobjects: Mobject): self.remove(*mobjects) self.mobjects = list(mobjects) + self.mobjects return self + @affects_mobject_list def clear(self): self.mobjects = [] return self