Reload user-defined modules during reload() (#2257)

* Experiment a lot with module loading

* Extract methods out of experimental mess

* Fix get module return type

* Only reload() modules during reload() command

* Remove unnecessary default parameter

* Add docstrings and logging statements

* Delete unwanted printout

* Improve logging messages

* Extract methods to a new class ModuleLoader

* Remove unused builtins import

* exec_module in any case at the end

* Clarify docstrings & move get_module method up in file

* Add more additionally excluded modules as array

* Distinguish between user-defined modules and external libraries like numpy

* Improved tracked_import docstring

* Remove _insert_embed suffix before logging

* Fix args.is_reload not defined error

* Refine logic to determine whether module is user-defined or not

* Fix list vs. set type annotations

* Improve docstrings & change order of early return

* Fix spelling mistake of "Reloading"

* Try out custom deep reload

* Make deep reload more robust

* Also reload modules imported as classes

* Move early return up to greatly improve performance

* Clean up comments

* Make methods of Module Loader "private"

* Add backticks around function in docstring

---------

Co-authored-by: Grant Sanderson <grant@3blue1brown.com>
This commit is contained in:
Splines 2024-12-06 01:18:10 +01:00 committed by GitHub
parent 94f6f0aa96
commit 6196daa5ec
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 188 additions and 16 deletions

View file

@ -13,6 +13,7 @@ import yaml
from functools import lru_cache
from manimlib.logger import log
from manimlib.module_loader import ModuleLoader
from manimlib.utils.dict_ops import merge_dicts_recursively
from manimlib.utils.init_config import init_customization
@ -209,22 +210,12 @@ def get_manim_dir():
return os.path.abspath(os.path.join(manimlib_dir, ".."))
def get_module(file_name: str | None) -> Module | None:
    """
    Import the Python file at `file_name` and return the resulting module.

    The module name is derived from the path: a trailing ".py" extension is
    dropped and path separators are turned into dots. Returns None when
    `file_name` is None.
    """
    if file_name is None:
        return None

    # Strip only the trailing extension; a ".py" appearing elsewhere in the
    # path (e.g. a directory called "my.pydir") must be left untouched.
    root, extension = os.path.splitext(file_name)
    if extension != ".py":
        root = file_name
    module_name = root.replace(os.sep, ".")

    spec = importlib.util.spec_from_file_location(module_name, file_name)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module
def get_indent(line: str):
    """Return how many leading whitespace characters `line` starts with."""
    without_leading_whitespace = line.lstrip()
    return len(line) - len(without_leading_whitespace)
def get_module_with_inserted_embed_line(
file_name: str, scene_name: str, line_marker: str
file_name: str, scene_name: str, line_marker: str, is_during_reload
):
"""
This is hacky, but convenient. When user includes the argument "-e", it will try
@ -286,7 +277,7 @@ def get_module_with_inserted_embed_line(
with open(new_file, 'w') as fp:
fp.writelines(new_lines)
module = get_module(new_file)
module = ModuleLoader.get_module(new_file, is_during_reload)
# This is to pretend the module imported from the edited lines
# of code actually comes from the original file.
module.__file__ = file_name
@ -298,10 +289,11 @@ def get_module_with_inserted_embed_line(
def get_scene_module(args: Namespace) -> Module:
if args.embed is None:
return get_module(args.file)
return ModuleLoader.get_module(args.file)
else:
is_reload = args.is_reload if hasattr(args, "is_reload") else False
return get_module_with_inserted_embed_line(
args.file, args.scene_names[0], args.embed
args.file, args.scene_names[0], args.embed, is_reload
)

176
manimlib/module_loader.py Normal file
View file

@ -0,0 +1,176 @@
from __future__ import annotations

import builtins
import importlib
import importlib.util
import os
import sys
import sysconfig
import types

from manimlib.logger import log
# Alias for the module type; used in annotations and `isinstance` checks.
# Taken from `types` directly instead of the undocumented `importlib.util.types`.
Module = types.ModuleType

# When True, `ModuleLoader._deep_reload` skips every module whose name
# starts with "manimlib".
IGNORE_MANIMLIB_MODULES = True
class ModuleLoader:
    """
    Utility class to load a module from a file and handle its imports.

    Most parts of this class are only needed for the reload functionality,
    while the `get_module` method is the main entry point to import a module.
    """

    @staticmethod
    def get_module(file_name: str | None, is_during_reload=False) -> Module | None:
        """
        Imports a module from a file and returns it.

        Returns None when `file_name` is None.

        During reload (when `is_during_reload` is True), we also track the
        modules imported by the file and reload the user-defined ones as well
        (they would be cached otherwise).

        Note that `exec_module()` is called twice when reloading a module:
        1. In `_exec_module_and_track_imports` to track the imports
        2. Here to actually execute the module again with the respective
           imported modules reloaded.
        """
        if file_name is None:
            return None

        module_name = ModuleLoader._module_name_from_path(file_name)
        spec = importlib.util.spec_from_file_location(module_name, file_name)
        module = importlib.util.module_from_spec(spec)

        if is_during_reload:
            imported_modules = ModuleLoader._exec_module_and_track_imports(spec, module)
            reloaded_modules_tracker = set()
            ModuleLoader._reload_modules(imported_modules, reloaded_modules_tracker)

        spec.loader.exec_module(module)
        return module

    @staticmethod
    def _module_name_from_path(file_name: str) -> str:
        """
        Derives a dotted module name from a file path.

        Only a trailing ".py" extension is removed; a ".py" that occurs
        elsewhere in the path (e.g. in a directory name) is kept as is.
        """
        root, extension = os.path.splitext(file_name)
        if extension != ".py":
            root = file_name
        return root.replace(os.sep, ".")

    @staticmethod
    def _exec_module_and_track_imports(spec, module: Module) -> set[str]:
        """
        Executes the given module (imports it) and returns the names of all
        modules that are imported during its execution.

        This is achieved by replacing the `__import__` function with a custom
        one that tracks the imported modules. At the end, the original
        `__import__` built-in function is restored, even if executing the
        module raises.
        """
        imported_modules: set[str] = set()
        original_import = builtins.__import__

        def tracked_import(name, globals=None, locals=None, fromlist=(), level=0):
            """
            Custom `__import__` function that does exactly the same as the
            original one, but also tracks the imported modules by means of
            adding their names to a set. A name is only recorded when the
            import succeeded.
            """
            result = original_import(name, globals, locals, fromlist, level)
            imported_modules.add(name)
            return result

        builtins.__import__ = tracked_import

        try:
            # Remove the "_insert_embed" suffix from the module name
            # (only for the log message, the module itself is untouched)
            suffix = "_insert_embed"
            module_name = module.__name__
            if module_name.endswith(suffix):
                module_name = module_name[:-len(suffix)]
            log.debug('Reloading module "%s"', module_name)

            spec.loader.exec_module(module)
        finally:
            builtins.__import__ = original_import

        return imported_modules

    @staticmethod
    def _reload_modules(modules: set[str], reloaded_modules_tracker: set[str]):
        """
        Out of the given modules, reloads the ones that were not already
        reloaded (i.e. whose name is not yet in `reloaded_modules_tracker`).

        We skip modules that are not user-defined
        (see `_is_user_defined_module()`).
        """
        for mod in modules:
            if mod in reloaded_modules_tracker:
                continue

            if not ModuleLoader._is_user_defined_module(mod):
                continue

            module = sys.modules[mod]
            ModuleLoader._deep_reload(module, reloaded_modules_tracker)

            reloaded_modules_tracker.add(mod)

    @staticmethod
    def _is_user_defined_module(mod: str) -> bool:
        """
        Returns whether the given module is user-defined or not.

        A module is considered user-defined if
        - it is currently loaded (present in `sys.modules`) and has a file
        - AND it is not a built-in module or part of the standard library
        - AND it is not an external library (site-packages or dist-packages)
        """
        if mod not in sys.modules:
            return False

        if mod in sys.builtin_module_names:
            return False

        module_path = getattr(sys.modules[mod], "__file__", None)
        if module_path is None:
            return False
        module_path = os.path.abspath(module_path)

        # External libraries (site-packages or dist-packages), e.g. numpy
        if "site-packages" in module_path or "dist-packages" in module_path:
            return False

        # Standard lib
        standard_lib_path = sysconfig.get_path("stdlib")
        if ModuleLoader._is_path_inside(module_path, standard_lib_path):
            return False

        return True

    @staticmethod
    def _is_path_inside(path: str, directory: str | None) -> bool:
        """
        Returns whether `path` is `directory` itself or located below it.

        The comparison is done on path-component boundaries, so that
        "/lib/python3-extras/x.py" is not considered to be inside
        "/lib/python3". Returns False when `directory` is empty or None.
        """
        if not directory:
            return False

        path = os.path.normcase(os.path.abspath(path))
        directory = os.path.normcase(os.path.abspath(directory))
        return path == directory or path.startswith(directory.rstrip(os.sep) + os.sep)

    @staticmethod
    def _deep_reload(module: Module, reloaded_modules_tracker: set[str]):
        """
        Recursively reloads modules imported by the given module, then
        reloads the module itself.

        Only user-defined modules are reloaded,
        see `_is_user_defined_module()`. Names of visited modules are added
        to `reloaded_modules_tracker` so that no module is reloaded twice.
        """
        if IGNORE_MANIMLIB_MODULES and module.__name__.startswith("manimlib"):
            return

        if not hasattr(module, "__dict__"):
            return

        # Prevent reloading the same module multiple times
        if module.__name__ in reloaded_modules_tracker:
            return
        reloaded_modules_tracker.add(module.__name__)

        # Recurse for all imported modules. We iterate over a snapshot since
        # reloading a child may add new attributes to this module (e.g. a
        # newly imported submodule is set on its parent package), which would
        # otherwise raise "dictionary changed size during iteration".
        for attr_value in list(module.__dict__.values()):
            if isinstance(attr_value, Module):
                if ModuleLoader._is_user_defined_module(attr_value.__name__):
                    ModuleLoader._deep_reload(attr_value, reloaded_modules_tracker)
                continue

            # Also reload modules that are part of a class or function
            # e.g. when importing `from custom_module import CustomClass`
            attr_module_name = getattr(attr_value, "__module__", None)
            if not isinstance(attr_module_name, str):
                continue
            if ModuleLoader._is_user_defined_module(attr_module_name):
                attr_module = sys.modules[attr_module_name]
                ModuleLoader._deep_reload(attr_module, reloaded_modules_tracker)

        # Reload
        log.debug('Reloading module "%s"', module.__name__)
        importlib.reload(module)

View file

@ -19,6 +19,8 @@ class ReloadManager:
# The line number to load the scene from when reloading
start_at_line = None
is_reload = False
def set_new_start_at_line(self, start_at_line):
"""
Sets/Updates the line number to load the scene from when reloading.
@ -41,6 +43,7 @@ class ReloadManager:
scene.tear_down()
self.scenes = []
self.is_reload = True
except KeyboardInterrupt:
break
@ -59,6 +62,7 @@ class ReloadManager:
self.args.embed = str(overwrite_start_at_line)
# Args to Config
self.args.is_reload = self.is_reload
scene_config = manimlib.config.get_scene_config(self.args)
if self.window:
scene_config["existing_window"] = self.window # see scene initialization

View file

@ -21,7 +21,7 @@ from manimlib.animation.animation import prepare_animation
from manimlib.animation.fading import VFadeInThenOut
from manimlib.camera.camera import Camera
from manimlib.camera.camera_frame import CameraFrame
from manimlib.config import get_module
from manimlib.module_loader import ModuleLoader
from manimlib.constants import ARROW_SYMBOLS
from manimlib.constants import DEFAULT_WAIT_TIME
from manimlib.constants import RED
@ -231,7 +231,7 @@ class Scene(object):
# Create embedded IPython terminal configured to have access to
# the local namespace of the caller
caller_frame = inspect.currentframe().f_back
module = get_module(caller_frame.f_globals["__file__"])
module = ModuleLoader.get_module(caller_frame.f_globals["__file__"])
shell = InteractiveShellEmbed(
user_module=module,
display_banner=False,