mirror of
https://github.com/3b1b/manim.git
synced 2025-08-05 16:49:03 +00:00
Simplify Mobject.copy to just use pickle serialization
This commit is contained in:
parent
c3afc84bfe
commit
1b009a4b03
4 changed files with 9 additions and 56 deletions
|
@ -1,7 +1,6 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
import copy
|
|
||||||
import random
|
import random
|
||||||
import itertools as it
|
import itertools as it
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
|
@ -25,7 +24,6 @@ from manimlib.utils.iterables import list_update
|
||||||
from manimlib.utils.iterables import resize_array
|
from manimlib.utils.iterables import resize_array
|
||||||
from manimlib.utils.iterables import resize_preserving_order
|
from manimlib.utils.iterables import resize_preserving_order
|
||||||
from manimlib.utils.iterables import resize_with_interpolation
|
from manimlib.utils.iterables import resize_with_interpolation
|
||||||
from manimlib.utils.iterables import make_even
|
|
||||||
from manimlib.utils.iterables import listify
|
from manimlib.utils.iterables import listify
|
||||||
from manimlib.utils.bezier import interpolate
|
from manimlib.utils.bezier import interpolate
|
||||||
from manimlib.utils.bezier import integer_interpolate
|
from manimlib.utils.bezier import integer_interpolate
|
||||||
|
@ -482,64 +480,25 @@ class Mobject(object):
|
||||||
# Copying
|
# Copying
|
||||||
|
|
||||||
def copy(self):
|
def copy(self):
|
||||||
# TODO, either justify reason for shallow copy, or
|
|
||||||
# remove this redundancy everywhere
|
|
||||||
# return self.deepcopy()
|
|
||||||
|
|
||||||
parents = self.parents
|
|
||||||
self.parents = []
|
self.parents = []
|
||||||
copy_mobject = copy.copy(self)
|
return pickle.loads(pickle.dumps(self))
|
||||||
self.parents = parents
|
|
||||||
|
|
||||||
copy_mobject.data = dict(self.data)
|
|
||||||
for key in self.data:
|
|
||||||
copy_mobject.data[key] = self.data[key].copy()
|
|
||||||
|
|
||||||
copy_mobject.uniforms = dict(self.uniforms)
|
|
||||||
for key in self.uniforms:
|
|
||||||
if isinstance(self.uniforms[key], np.ndarray):
|
|
||||||
copy_mobject.uniforms[key] = self.uniforms[key].copy()
|
|
||||||
|
|
||||||
copy_mobject.submobjects = []
|
|
||||||
copy_mobject.add(*[sm.copy() for sm in self.submobjects])
|
|
||||||
copy_mobject.match_updaters(self)
|
|
||||||
|
|
||||||
copy_mobject.needs_new_bounding_box = self.needs_new_bounding_box
|
|
||||||
|
|
||||||
# Make sure any mobject or numpy array attributes are copied
|
|
||||||
family = self.get_family()
|
|
||||||
for attr, value in list(self.__dict__.items()):
|
|
||||||
if isinstance(value, Mobject) and value in family and value is not self:
|
|
||||||
setattr(copy_mobject, attr, value.copy())
|
|
||||||
if isinstance(value, np.ndarray):
|
|
||||||
setattr(copy_mobject, attr, value.copy())
|
|
||||||
if isinstance(value, ShaderWrapper):
|
|
||||||
setattr(copy_mobject, attr, value.copy())
|
|
||||||
return copy_mobject
|
|
||||||
|
|
||||||
def deepcopy(self):
|
def deepcopy(self):
|
||||||
parents = self.parents
|
# This used to be different from copy, so is now just here for backward compatibility
|
||||||
self.parents = []
|
return self.copy()
|
||||||
result = copy.deepcopy(self)
|
|
||||||
self.parents = parents
|
|
||||||
return result
|
|
||||||
|
|
||||||
def generate_target(self, use_deepcopy: bool = False):
|
def generate_target(self, use_deepcopy: bool = False):
|
||||||
|
# TODO, remove now pointless use_deepcopy arg
|
||||||
self.target = None # Prevent exponential explosion
|
self.target = None # Prevent exponential explosion
|
||||||
if use_deepcopy:
|
self.target = self.copy()
|
||||||
self.target = self.deepcopy()
|
|
||||||
else:
|
|
||||||
self.target = self.copy()
|
|
||||||
return self.target
|
return self.target
|
||||||
|
|
||||||
def save_state(self, use_deepcopy: bool = False):
|
def save_state(self, use_deepcopy: bool = False):
|
||||||
|
# TODO, remove now pointless use_deepcopy arg
|
||||||
if hasattr(self, "saved_state"):
|
if hasattr(self, "saved_state"):
|
||||||
# Prevent exponential growth of data
|
# Prevent exponential growth of data
|
||||||
self.saved_state = None
|
self.saved_state = None
|
||||||
if use_deepcopy:
|
self.saved_state = self.copy()
|
||||||
self.saved_state = self.deepcopy()
|
|
||||||
else:
|
|
||||||
self.saved_state = self.copy()
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def restore(self):
|
def restore(self):
|
||||||
|
@ -1473,7 +1432,7 @@ class Mobject(object):
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def push_self_into_submobjects(self):
|
def push_self_into_submobjects(self):
|
||||||
copy = self.deepcopy()
|
copy = self.copy()
|
||||||
copy.set_submobjects([])
|
copy.set_submobjects([])
|
||||||
self.resize_points(0)
|
self.resize_points(0)
|
||||||
self.add(copy)
|
self.add(copy)
|
||||||
|
|
|
@ -274,6 +274,3 @@ class BarChart(VGroup):
|
||||||
(value / self.max_value) * self.height
|
(value / self.max_value) * self.height
|
||||||
)
|
)
|
||||||
bar.move_to(bar_bottom, DOWN)
|
bar.move_to(bar_bottom, DOWN)
|
||||||
|
|
||||||
def copy(self):
|
|
||||||
return self.deepcopy()
|
|
||||||
|
|
|
@ -123,9 +123,6 @@ class LabelledString(_StringSVG, ABC):
|
||||||
self.group_substrs = self.get_group_substrs()
|
self.group_substrs = self.get_group_substrs()
|
||||||
self.submob_groups = self.get_submob_groups()
|
self.submob_groups = self.get_submob_groups()
|
||||||
|
|
||||||
def copy(self):
|
|
||||||
return self.deepcopy()
|
|
||||||
|
|
||||||
# Toolkits
|
# Toolkits
|
||||||
|
|
||||||
def get_substr(self, span: Span) -> str:
|
def get_substr(self, span: Span) -> str:
|
||||||
|
|
|
@ -326,7 +326,7 @@ class ShowPassingFlashWithThinningStrokeWidth(AnimationGroup):
|
||||||
max_time_width = kwargs.pop("time_width", self.time_width)
|
max_time_width = kwargs.pop("time_width", self.time_width)
|
||||||
AnimationGroup.__init__(self, *[
|
AnimationGroup.__init__(self, *[
|
||||||
VShowPassingFlash(
|
VShowPassingFlash(
|
||||||
vmobject.deepcopy().set_stroke(width=stroke_width),
|
vmobject.copy().set_stroke(width=stroke_width),
|
||||||
time_width=time_width,
|
time_width=time_width,
|
||||||
**kwargs
|
**kwargs
|
||||||
)
|
)
|
||||||
|
|
Loading…
Add table
Reference in a new issue