diff --git a/manimlib/scene/scene.py b/manimlib/scene/scene.py index a0c5d947..a0ffd1e4 100644 --- a/manimlib/scene/scene.py +++ b/manimlib/scene/scene.py @@ -1,5 +1,6 @@ from __future__ import annotations +from collections import OrderedDict from functools import wraps import inspect import os @@ -646,52 +647,37 @@ class Scene(object): # Helpers for interactive development - def get_state(self) -> tuple[list[tuple[Mobject, Mobject]], int]: - if self.undo_stack: - last_state = dict(self.undo_stack[-1]) - else: - last_state = {} - result = [] - n_changes = 0 - for mob in self.mobjects: - # If it hasn't changed since the last state, just point to the - # same copy as before - if mob in last_state and last_state[mob].looks_identical(mob): - result.append((mob, last_state[mob])) - else: - result.append((mob, mob.copy())) - n_changes += 1 - return result, n_changes + def get_state(self) -> SceneState: + return SceneState(self) - def restore_state(self, mobject_states: list[tuple[Mobject, Mobject]]): - self.mobjects = [mob.become(mob_copy) for mob, mob_copy in mobject_states] + def restore_state(self, scene_state: SceneState): + scene_state.restore_scene(self) def save_state(self) -> None: if not self.preview: return + state = self.get_state() + if self.undo_stack and state.mobjects_match(self.undo_stack[-1]): + return self.redo_stack = [] - state, n_changes = self.get_state() - if n_changes > 0: - self.undo_stack.append(state) - if len(self.undo_stack) > self.max_num_saved_states: - self.undo_stack.pop(0) + self.undo_stack.append(state) + if len(self.undo_stack) > self.max_num_saved_states: + self.undo_stack.pop(0) def undo(self): if self.undo_stack: - state, n_changes = self.get_state() - self.redo_stack.append(state) + self.redo_stack.append(self.get_state()) self.restore_state(self.undo_stack.pop()) self.refresh_static_mobjects() def redo(self): if self.redo_stack: - state, n_changes = self.get_state() - self.undo_stack.append(state) + self.undo_stack.append(self.get_state()) self.restore_state(self.redo_stack.pop()) self.refresh_static_mobjects() def checkpoint(self, key: str): - self.checkpoint_states[key] = self.get_state()[0] + self.checkpoint_states[key] = self.get_state() def revert_to_checkpoint(self, key: str): if key not in self.checkpoint_states: @@ -858,5 +844,49 @@ class Scene(object): pass +class SceneState(): + def __init__(self, scene: Scene, ignore: list[Mobject] | None = None): + self.time = scene.time + self.num_plays = scene.num_plays + self.mobjects_to_copies = OrderedDict.fromkeys(scene.mobjects) + if ignore: + for mob in ignore: + self.mobjects_to_copies.pop(mob, None) + + last_m2c = scene.undo_stack[-1].mobjects_to_copies if scene.undo_stack else dict() + for mob in self.mobjects_to_copies: + # If it hasn't changed since the last state, just point to the + # same copy as before + if mob in last_m2c and last_m2c[mob].looks_identical(mob): + self.mobjects_to_copies[mob] = last_m2c[mob] + else: + self.mobjects_to_copies[mob] = mob.copy() + + def __eq__(self, state: SceneState): + return all(( + self.time == state.time, + self.num_plays == state.num_plays, + self.mobjects_to_copies == state.mobjects_to_copies + )) + + def mobjects_match(self, state: SceneState): + return self.mobjects_to_copies == state.mobjects_to_copies + + def n_changes(self, state: SceneState): + m2c = state.mobjects_to_copies + return sum( + 1 - int(mob in m2c and mob.looks_identical(m2c[mob])) + for mob in self.mobjects_to_copies + ) + + def restore_scene(self, scene: Scene): + scene.time = self.time + scene.num_plays = self.num_plays + scene.mobjects = [ + mob.become(mob_copy) + for mob, mob_copy in self.mobjects_to_copies.items() + ] + + class EndSceneEarlyException(Exception): pass