mirror of
https://github.com/3b1b/manim.git
synced 2025-08-05 16:49:03 +00:00
Reload user-defined modules during reload()
(#2257)
* Experiment a lot with module loading * Extract methods out of experimental mess * Fix get module return type * Only reload() modules during reload() command * Remove unnecessary default parameter * Add docstrings and logging statements * Delete unwanted printout * Improve logging messages * Extract methods to a new class ModuleLoader * Remove unused builtins import * exec_module in any case at the end * Clarify docstrings & move get_module method up in file * Add more additionally excluded modules as array * Distinguish between user-defined modules and external libraries like numpy * Improved tracked_import docstring * Remove _insert_embed suffix before logging * Fix args.is_reload not defined error * Refine logic to determine whether module is user-defined or not * Fix list vs. set type annotations * Improve docstrings & change order of early return * Fix spelling mistake of "Reloading" * Try out custom deep reload * Make deep reload more robust * Also reload modules imported as classes * Move early return up to greatly improve performance * Clean up comments * Make methods of Module Loader "private" * Add backticks around function in docstring --------- Co-authored-by: Grant Sanderson <grant@3blue1brown.com>
This commit is contained in:
parent
94f6f0aa96
commit
6196daa5ec
4 changed files with 188 additions and 16 deletions
|
@ -13,6 +13,7 @@ import yaml
|
|||
from functools import lru_cache
|
||||
|
||||
from manimlib.logger import log
|
||||
from manimlib.module_loader import ModuleLoader
|
||||
from manimlib.utils.dict_ops import merge_dicts_recursively
|
||||
from manimlib.utils.init_config import init_customization
|
||||
|
||||
|
@ -209,22 +210,12 @@ def get_manim_dir():
|
|||
return os.path.abspath(os.path.join(manimlib_dir, ".."))
|
||||
|
||||
|
||||
def get_module(file_name: str | None) -> Module:
|
||||
if file_name is None:
|
||||
return None
|
||||
module_name = file_name.replace(os.sep, ".").replace(".py", "")
|
||||
spec = importlib.util.spec_from_file_location(module_name, file_name)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module)
|
||||
return module
|
||||
|
||||
|
||||
def get_indent(line: str):
|
||||
return len(line) - len(line.lstrip())
|
||||
|
||||
|
||||
def get_module_with_inserted_embed_line(
|
||||
file_name: str, scene_name: str, line_marker: str
|
||||
file_name: str, scene_name: str, line_marker: str, is_during_reload
|
||||
):
|
||||
"""
|
||||
This is hacky, but convenient. When user includes the argument "-e", it will try
|
||||
|
@ -286,7 +277,7 @@ def get_module_with_inserted_embed_line(
|
|||
with open(new_file, 'w') as fp:
|
||||
fp.writelines(new_lines)
|
||||
|
||||
module = get_module(new_file)
|
||||
module = ModuleLoader.get_module(new_file, is_during_reload)
|
||||
# This is to pretend the module imported from the edited lines
|
||||
# of code actually comes from the original file.
|
||||
module.__file__ = file_name
|
||||
|
@ -298,10 +289,11 @@ def get_module_with_inserted_embed_line(
|
|||
|
||||
def get_scene_module(args: Namespace) -> Module:
|
||||
if args.embed is None:
|
||||
return get_module(args.file)
|
||||
return ModuleLoader.get_module(args.file)
|
||||
else:
|
||||
is_reload = args.is_reload if hasattr(args, "is_reload") else False
|
||||
return get_module_with_inserted_embed_line(
|
||||
args.file, args.scene_names[0], args.embed
|
||||
args.file, args.scene_names[0], args.embed, is_reload
|
||||
)
|
||||
|
||||
|
||||
|
|
176
manimlib/module_loader.py
Normal file
176
manimlib/module_loader.py
Normal file
|
@ -0,0 +1,176 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import builtins
|
||||
import importlib
|
||||
import os
|
||||
import sys
|
||||
import sysconfig
|
||||
|
||||
from manimlib.logger import log
|
||||
|
||||
Module = importlib.util.types.ModuleType
|
||||
|
||||
IGNORE_MANIMLIB_MODULES = True
|
||||
|
||||
|
||||
class ModuleLoader:
|
||||
"""
|
||||
Utility class to load a module from a file and handle its imports.
|
||||
|
||||
Most parts of this class are only needed for the reload functionality,
|
||||
while the `get_module` method is the main entry point to import a module.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def get_module(file_name: str | None, is_during_reload=False) -> Module | None:
|
||||
"""
|
||||
Imports a module from a file and returns it.
|
||||
|
||||
During reload (when the user calls `reload()` in the IPython shell), we
|
||||
also track the imported modules and reload them as well (they would be
|
||||
cached otherwise). See the reload_manager where the reload parameter is set.
|
||||
|
||||
Note that `exec_module()` is called twice when reloading a module:
|
||||
1. In exec_module_and_track_imports to track the imports
|
||||
2. Here to actually execute the module again with the respective
|
||||
imported modules reloaded.
|
||||
"""
|
||||
if file_name is None:
|
||||
return None
|
||||
|
||||
module_name = file_name.replace(os.sep, ".").replace(".py", "")
|
||||
spec = importlib.util.spec_from_file_location(module_name, file_name)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
|
||||
if is_during_reload:
|
||||
imported_modules = ModuleLoader._exec_module_and_track_imports(spec, module)
|
||||
reloaded_modules_tracker = set()
|
||||
ModuleLoader._reload_modules(imported_modules, reloaded_modules_tracker)
|
||||
|
||||
spec.loader.exec_module(module)
|
||||
return module
|
||||
|
||||
@staticmethod
|
||||
def _exec_module_and_track_imports(spec, module: Module) -> set[str]:
|
||||
"""
|
||||
Executes the given module (imports it) and returns all the modules that
|
||||
are imported during its execution.
|
||||
|
||||
This is achieved by replacing the __import__ function with a custom one
|
||||
that tracks the imported modules. At the end, the original __import__
|
||||
built-in function is restored.
|
||||
"""
|
||||
imported_modules: set[str] = set()
|
||||
original_import = builtins.__import__
|
||||
|
||||
def tracked_import(name, globals=None, locals=None, fromlist=(), level=0):
|
||||
"""
|
||||
Custom __import__ function that does exactly the same as the original
|
||||
one, but also tracks the imported modules by means of adding their
|
||||
names to a set.
|
||||
"""
|
||||
result = original_import(name, globals, locals, fromlist, level)
|
||||
imported_modules.add(name)
|
||||
return result
|
||||
|
||||
builtins.__import__ = tracked_import
|
||||
|
||||
try:
|
||||
# Remove the "_insert_embed" suffix from the module name
|
||||
module_name = module.__name__
|
||||
if module.__name__.endswith("_insert_embed"):
|
||||
module_name = module_name[:-13]
|
||||
log.debug('Reloading module "%s"', module_name)
|
||||
|
||||
spec.loader.exec_module(module)
|
||||
finally:
|
||||
builtins.__import__ = original_import
|
||||
|
||||
return imported_modules
|
||||
|
||||
@staticmethod
|
||||
def _reload_modules(modules: set[str], reloaded_modules_tracker: set[str]):
|
||||
"""
|
||||
Out of the given modules, reloads the ones that were not already imported.
|
||||
|
||||
We skip modules that are not user-defined (see `is_user_defined_module()`).
|
||||
"""
|
||||
for mod in modules:
|
||||
if mod in reloaded_modules_tracker:
|
||||
continue
|
||||
|
||||
if not ModuleLoader._is_user_defined_module(mod):
|
||||
continue
|
||||
|
||||
module = sys.modules[mod]
|
||||
ModuleLoader._deep_reload(module, reloaded_modules_tracker)
|
||||
|
||||
reloaded_modules_tracker.add(mod)
|
||||
|
||||
@staticmethod
|
||||
def _is_user_defined_module(mod: str) -> bool:
|
||||
"""
|
||||
Returns whether the given module is user-defined or not.
|
||||
|
||||
A module is considered user-defined if
|
||||
- it is not part of the standard library
|
||||
- AND it is not an external library (site-packages or dist-packages)
|
||||
"""
|
||||
if mod not in sys.modules:
|
||||
return False
|
||||
|
||||
if mod in sys.builtin_module_names:
|
||||
return False
|
||||
|
||||
module = sys.modules[mod]
|
||||
module_path = getattr(module, "__file__", None)
|
||||
if module_path is None:
|
||||
return False
|
||||
module_path = os.path.abspath(module_path)
|
||||
|
||||
# External libraries (site-packages or dist-packages), e.g. numpy
|
||||
if "site-packages" in module_path or "dist-packages" in module_path:
|
||||
return False
|
||||
|
||||
# Standard lib
|
||||
standard_lib_path = sysconfig.get_path("stdlib")
|
||||
if module_path.startswith(standard_lib_path):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def _deep_reload(module: Module, reloaded_modules_tracker: set[str]):
|
||||
"""
|
||||
Recursively reloads modules imported by the given module.
|
||||
|
||||
Only user-defined modules are reloaded, see `is_user_defined_module()`.
|
||||
"""
|
||||
if IGNORE_MANIMLIB_MODULES and module.__name__.startswith("manimlib"):
|
||||
return
|
||||
|
||||
if not hasattr(module, "__dict__"):
|
||||
return
|
||||
|
||||
# Prevent reloading the same module multiple times
|
||||
if module.__name__ in reloaded_modules_tracker:
|
||||
return
|
||||
reloaded_modules_tracker.add(module.__name__)
|
||||
|
||||
# Recurse for all imported modules
|
||||
for _attr_name, attr_value in module.__dict__.items():
|
||||
if isinstance(attr_value, Module):
|
||||
if ModuleLoader._is_user_defined_module(attr_value.__name__):
|
||||
ModuleLoader._deep_reload(attr_value, reloaded_modules_tracker)
|
||||
|
||||
# Also reload modules that are part of a class or function
|
||||
# e.g. when importing `from custom_module import CustomClass`
|
||||
elif hasattr(attr_value, "__module__"):
|
||||
attr_module_name = attr_value.__module__
|
||||
if ModuleLoader._is_user_defined_module(attr_module_name):
|
||||
attr_module = sys.modules[attr_module_name]
|
||||
ModuleLoader._deep_reload(attr_module, reloaded_modules_tracker)
|
||||
|
||||
# Reload
|
||||
log.debug('Reloading module "%s"', module.__name__)
|
||||
importlib.reload(module)
|
|
@ -19,6 +19,8 @@ class ReloadManager:
|
|||
# The line number to load the scene from when reloading
|
||||
start_at_line = None
|
||||
|
||||
is_reload = False
|
||||
|
||||
def set_new_start_at_line(self, start_at_line):
|
||||
"""
|
||||
Sets/Updates the line number to load the scene from when reloading.
|
||||
|
@ -41,6 +43,7 @@ class ReloadManager:
|
|||
scene.tear_down()
|
||||
|
||||
self.scenes = []
|
||||
self.is_reload = True
|
||||
|
||||
except KeyboardInterrupt:
|
||||
break
|
||||
|
@ -59,6 +62,7 @@ class ReloadManager:
|
|||
self.args.embed = str(overwrite_start_at_line)
|
||||
|
||||
# Args to Config
|
||||
self.args.is_reload = self.is_reload
|
||||
scene_config = manimlib.config.get_scene_config(self.args)
|
||||
if self.window:
|
||||
scene_config["existing_window"] = self.window # see scene initialization
|
||||
|
|
|
@ -21,7 +21,7 @@ from manimlib.animation.animation import prepare_animation
|
|||
from manimlib.animation.fading import VFadeInThenOut
|
||||
from manimlib.camera.camera import Camera
|
||||
from manimlib.camera.camera_frame import CameraFrame
|
||||
from manimlib.config import get_module
|
||||
from manimlib.module_loader import ModuleLoader
|
||||
from manimlib.constants import ARROW_SYMBOLS
|
||||
from manimlib.constants import DEFAULT_WAIT_TIME
|
||||
from manimlib.constants import RED
|
||||
|
@ -231,7 +231,7 @@ class Scene(object):
|
|||
# Create embedded IPython terminal configured to have access to
|
||||
# the local namespace of the caller
|
||||
caller_frame = inspect.currentframe().f_back
|
||||
module = get_module(caller_frame.f_globals["__file__"])
|
||||
module = ModuleLoader.get_module(caller_frame.f_globals["__file__"])
|
||||
shell = InteractiveShellEmbed(
|
||||
user_module=module,
|
||||
display_banner=False,
|
||||
|
|
Loading…
Add table
Reference in a new issue