From 6196daa5ec49546a8d973b6866991586f98d0d56 Mon Sep 17 00:00:00 2001 From: Splines <37160523+Splines@users.noreply.github.com> Date: Fri, 6 Dec 2024 01:18:10 +0100 Subject: [PATCH] Reload user-defined modules during `reload()` (#2257) * Experiment a lot with module loading * Extract methods out of experimental mess * Fix get module return type * Only reload() modules during reload() command * Remove unnecessary default parameter * Add docstrings and logging statements * Delete unwanted printout * Improve logging messages * Extract methods to a new class ModuleLoader * Remove unused builtins import * exec_module in any case at the end * Clarify docstrings & move get_module method up in file * Add more additionally excluded modules as array * Distinguish between user-defined modules and external libraries like numpy * Improved tracked_import docstring * Remove _insert_embed suffix before logging * Fix args.is_reload not defined error * Refine logic to determine whether module is user-defined or not * Fix list vs. 
set type annotations * Improve docstrings & change order of early return * Fix spelling mistake of "Reloading" * Try out custom deep reload * Make deep reload more robust * Also reload modules imported as classes * Move early return up to greatly improve performance * Clean up comments * Make methods of Module Loader "private" * Add backticks around function in docstring --------- Co-authored-by: Grant Sanderson --- manimlib/config.py | 20 ++--- manimlib/module_loader.py | 176 +++++++++++++++++++++++++++++++++++++ manimlib/reload_manager.py | 4 + manimlib/scene/scene.py | 4 +- 4 files changed, 188 insertions(+), 16 deletions(-) create mode 100644 manimlib/module_loader.py diff --git a/manimlib/config.py b/manimlib/config.py index 3e299459..4d0a7a99 100644 --- a/manimlib/config.py +++ b/manimlib/config.py @@ -13,6 +13,7 @@ import yaml from functools import lru_cache from manimlib.logger import log +from manimlib.module_loader import ModuleLoader from manimlib.utils.dict_ops import merge_dicts_recursively from manimlib.utils.init_config import init_customization @@ -209,22 +210,12 @@ def get_manim_dir(): return os.path.abspath(os.path.join(manimlib_dir, "..")) -def get_module(file_name: str | None) -> Module: - if file_name is None: - return None - module_name = file_name.replace(os.sep, ".").replace(".py", "") - spec = importlib.util.spec_from_file_location(module_name, file_name) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - return module - - def get_indent(line: str): return len(line) - len(line.lstrip()) def get_module_with_inserted_embed_line( - file_name: str, scene_name: str, line_marker: str + file_name: str, scene_name: str, line_marker: str, is_during_reload ): """ This is hacky, but convenient. 
When user includes the argument "-e", it will try @@ -286,7 +277,7 @@ def get_module_with_inserted_embed_line( with open(new_file, 'w') as fp: fp.writelines(new_lines) - module = get_module(new_file) + module = ModuleLoader.get_module(new_file, is_during_reload) # This is to pretend the module imported from the edited lines # of code actually comes from the original file. module.__file__ = file_name @@ -298,10 +289,11 @@ def get_module_with_inserted_embed_line( def get_scene_module(args: Namespace) -> Module: if args.embed is None: - return get_module(args.file) + return ModuleLoader.get_module(args.file) else: + is_reload = args.is_reload if hasattr(args, "is_reload") else False return get_module_with_inserted_embed_line( - args.file, args.scene_names[0], args.embed + args.file, args.scene_names[0], args.embed, is_reload ) diff --git a/manimlib/module_loader.py b/manimlib/module_loader.py new file mode 100644 index 00000000..894f12a0 --- /dev/null +++ b/manimlib/module_loader.py @@ -0,0 +1,176 @@ +from __future__ import annotations + +import builtins +import importlib +import os +import sys +import sysconfig + +from manimlib.logger import log + +Module = importlib.util.types.ModuleType + +IGNORE_MANIMLIB_MODULES = True + + +class ModuleLoader: + """ + Utility class to load a module from a file and handle its imports. + + Most parts of this class are only needed for the reload functionality, + while the `get_module` method is the main entry point to import a module. + """ + + @staticmethod + def get_module(file_name: str | None, is_during_reload=False) -> Module | None: + """ + Imports a module from a file and returns it. + + During reload (when the user calls `reload()` in the IPython shell), we + also track the imported modules and reload them as well (they would be + cached otherwise). See the reload_manager where the reload parameter is set. + + Note that `exec_module()` is called twice when reloading a module: + 1. 
In `_exec_module_and_track_imports` to track the imports + 2. Here to actually execute the module again with the respective + imported modules reloaded. + """ + if file_name is None: + return None + + module_name = file_name.replace(os.sep, ".").replace(".py", "") + spec = importlib.util.spec_from_file_location(module_name, file_name) + module = importlib.util.module_from_spec(spec) + + if is_during_reload: + imported_modules = ModuleLoader._exec_module_and_track_imports(spec, module) + reloaded_modules_tracker = set() + ModuleLoader._reload_modules(imported_modules, reloaded_modules_tracker) + + spec.loader.exec_module(module) + return module + + @staticmethod + def _exec_module_and_track_imports(spec, module: Module) -> set[str]: + """ + Executes the given module (imports it) and returns all the modules that + are imported during its execution. + + This is achieved by replacing the __import__ function with a custom one + that tracks the imported modules. At the end, the original __import__ + built-in function is restored. + """ + imported_modules: set[str] = set() + original_import = builtins.__import__ + + def tracked_import(name, globals=None, locals=None, fromlist=(), level=0): + """ + Custom __import__ function that does exactly the same as the original + one, but also tracks the imported modules by means of adding their + names to a set. 
+ """ + result = original_import(name, globals, locals, fromlist, level) + imported_modules.add(name) + return result + + builtins.__import__ = tracked_import + + try: + # Remove the "_insert_embed" suffix from the module name + module_name = module.__name__ + if module.__name__.endswith("_insert_embed"): + module_name = module_name[:-13] + log.debug('Reloading module "%s"', module_name) + + spec.loader.exec_module(module) + finally: + builtins.__import__ = original_import + + return imported_modules + + @staticmethod + def _reload_modules(modules: set[str], reloaded_modules_tracker: set[str]): + """ + Out of the given modules, reloads the ones that were not already reloaded. + + We skip modules that are not user-defined (see `_is_user_defined_module()`). + """ + for mod in modules: + if mod in reloaded_modules_tracker: + continue + + if not ModuleLoader._is_user_defined_module(mod): + continue + + module = sys.modules[mod] + ModuleLoader._deep_reload(module, reloaded_modules_tracker) + + reloaded_modules_tracker.add(mod) + + @staticmethod + def _is_user_defined_module(mod: str) -> bool: + """ + Returns whether the given module is user-defined or not. + + A module is considered user-defined if + - it is not part of the standard library + - AND it is not an external library (site-packages or dist-packages) + """ + if mod not in sys.modules: + return False + + if mod in sys.builtin_module_names: + return False + + module = sys.modules[mod] + module_path = getattr(module, "__file__", None) + if module_path is None: + return False + module_path = os.path.abspath(module_path) + + # External libraries (site-packages or dist-packages), e.g. 
numpy + if "site-packages" in module_path or "dist-packages" in module_path: + return False + + # Standard lib + standard_lib_path = sysconfig.get_path("stdlib") + if module_path.startswith(standard_lib_path): + return False + + return True + + @staticmethod + def _deep_reload(module: Module, reloaded_modules_tracker: set[str]): + """ + Recursively reloads modules imported by the given module. + + Only user-defined modules are reloaded, see `_is_user_defined_module()`. + """ + if IGNORE_MANIMLIB_MODULES and module.__name__.startswith("manimlib"): + return + + if not hasattr(module, "__dict__"): + return + + # Prevent reloading the same module multiple times + if module.__name__ in reloaded_modules_tracker: + return + reloaded_modules_tracker.add(module.__name__) + + # Recurse for all imported modules + for _attr_name, attr_value in module.__dict__.items(): + if isinstance(attr_value, Module): + if ModuleLoader._is_user_defined_module(attr_value.__name__): + ModuleLoader._deep_reload(attr_value, reloaded_modules_tracker) + + # Also reload modules that are part of a class or function + # e.g. when importing `from custom_module import CustomClass` + elif hasattr(attr_value, "__module__"): + attr_module_name = attr_value.__module__ + if ModuleLoader._is_user_defined_module(attr_module_name): + attr_module = sys.modules[attr_module_name] + ModuleLoader._deep_reload(attr_module, reloaded_modules_tracker) + + # Reload + log.debug('Reloading module "%s"', module.__name__) + importlib.reload(module) diff --git a/manimlib/reload_manager.py b/manimlib/reload_manager.py index 3d58c97a..e7887930 100644 --- a/manimlib/reload_manager.py +++ b/manimlib/reload_manager.py @@ -19,6 +19,8 @@ class ReloadManager: # The line number to load the scene from when reloading start_at_line = None + is_reload = False + def set_new_start_at_line(self, start_at_line): """ Sets/Updates the line number to load the scene from when reloading. 
@@ -41,6 +43,7 @@ class ReloadManager: scene.tear_down() self.scenes = [] + self.is_reload = True except KeyboardInterrupt: break @@ -59,6 +62,7 @@ class ReloadManager: self.args.embed = str(overwrite_start_at_line) # Args to Config + self.args.is_reload = self.is_reload scene_config = manimlib.config.get_scene_config(self.args) if self.window: scene_config["existing_window"] = self.window # see scene initialization diff --git a/manimlib/scene/scene.py b/manimlib/scene/scene.py index 8fa2c139..abea0ba8 100644 --- a/manimlib/scene/scene.py +++ b/manimlib/scene/scene.py @@ -21,7 +21,7 @@ from manimlib.animation.animation import prepare_animation from manimlib.animation.fading import VFadeInThenOut from manimlib.camera.camera import Camera from manimlib.camera.camera_frame import CameraFrame -from manimlib.config import get_module +from manimlib.module_loader import ModuleLoader from manimlib.constants import ARROW_SYMBOLS from manimlib.constants import DEFAULT_WAIT_TIME from manimlib.constants import RED @@ -231,7 +231,7 @@ class Scene(object): # Create embedded IPython terminal configured to have access to # the local namespace of the caller caller_frame = inspect.currentframe().f_back - module = get_module(caller_frame.f_globals["__file__"]) + module = ModuleLoader.get_module(caller_frame.f_globals["__file__"]) shell = InteractiveShellEmbed( user_module=module, display_banner=False,