Simplify Scene.remove to not require a Mobject.get_ancestors call

This commit is contained in:
Grant Sanderson 2022-12-26 07:46:40 -07:00
parent fcff44a66b
commit 22c67df2ad
2 changed files with 34 additions and 9 deletions

View file

@ -33,6 +33,7 @@ from manimlib.mobject.types.vectorized_mobject import VGroup
from manimlib.mobject.types.vectorized_mobject import VMobject
from manimlib.scene.scene_file_writer import SceneFileWriter
from manimlib.utils.family_ops import extract_mobject_family_members
from manimlib.utils.family_ops import recursive_mobject_remove
from manimlib.utils.iterables import list_difference_update
from typing import TYPE_CHECKING
@ -377,7 +378,7 @@ class Scene(object):
]
return self
def remove(self, *mobjects: Mobject):
def remove(self, *mobjects_to_remove: Mobject):
"""
Removes anything in mobjects from scenes mobject list, but in the event that one
of the items to be removed is a member of the family of an item in mobject_list,
@ -386,13 +387,9 @@ class Scene(object):
For example, if the scene includes Group(m1, m2, m3), and we call scene.remove(m1),
the desired behavior is for the scene to then include m2 and m3 (ungrouped).
"""
for mob in mobjects:
# First restructure self.mobjects so that parents/grandparents/etc. are replaced
# with their children, likewise for all ancestors in the extended family.
for ancestor in mob.get_ancestors(extended=True):
self.replace(ancestor, *ancestor.submobjects)
self.mobjects = list_difference_update(self.mobjects, mob.get_family())
return self
to_remove = set(extract_mobject_family_members(mobjects_to_remove))
new_mobjects, _ = recursive_mobject_remove(self.mobjects, to_remove)
self.mobjects = new_mobjects
def bring_to_front(self, *mobjects: Mobject):
self.add(*mobjects)

View file

@ -3,7 +3,7 @@ from __future__ import annotations
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Iterable
from typing import Iterable, List, Set, Tuple
from manimlib.mobject.mobject import Mobject
@ -18,3 +18,31 @@ def extract_mobject_family_members(
for sm in mob.get_family()
if (not exclude_pointless) or sm.has_points()
]
def recursive_mobject_remove(mobjects: List[Mobject], to_remove: Set[Mobject]) -> Tuple[List[Mobject], bool]:
"""
Takes in a list of mobjects, together with a set of mobjects to remove.
The first component of what's removed is a new list such that any mobject
with one of the elements from `to_remove` in its family is no longer in
the list, and in its place are its family members which aren't in `to_remove`
The second component is a boolean value indicating whether any removals were made
"""
result = []
found_in_list = False
for mob in mobjects:
if mob in to_remove:
found_in_list = True
continue
# Recursive call
sub_list, found_in_submobjects = recursive_mobject_remove(
mob.submobjects, to_remove
)
if found_in_submobjects:
result.extend(sub_list)
found_in_list = True
else:
result.append(mob)
return result, found_in_list