3b1b-manim/manimlib/module_loader.py
Grant Sanderson bafea89ac9
Some checks are pending
docs / build up document and deploy (push) Waiting to run
Update InteractiveSceneEmbed (#2267)
* Comment tweak

* Directly print traceback

Since the shell.showtraceback is giving some issues

* Make InteractiveSceneEmbed into a class

This way it can keep track of its internal shell; use of get_ipython has a finicky relationship with reloading.

* Move remaining checkpoint_paste logic into scene_embed.py

This involved making a few context managers for Scene: temp_record, temp_skip, temp_progress_bar, which seem useful in and of themselves.

* Change null key to be the empty string
2024-12-11 11:33:48 -08:00

176 lines
6.3 KiB
Python

from __future__ import annotations

import builtins
import importlib
import importlib.util
import os
import sys
import sysconfig
import types

from manimlib.config import manim_config
from manimlib.logger import log
# Alias for the module type, used for annotations and isinstance checks.
# Taken from `types` directly rather than through `importlib.util.types`,
# which only works because `importlib.util` happens to import `types`.
Module = types.ModuleType
class ModuleLoader:
"""
Utility class to load a module from a file and handle its imports.
Most parts of this class are only needed for the reload functionality,
while the `get_module` method is the main entry point to import a module.
"""
@staticmethod
def get_module(file_name: str | None, is_during_reload=False) -> Module | None:
"""
Imports a module from a file and returns it.
During reload (when the user calls `reload()` in the IPython shell), we
also track the imported modules and reload them as well (they would be
cached otherwise). See the reload_manager where the reload parameter is set.
Note that `exec_module()` is called twice when reloading a module:
1. In exec_module_and_track_imports to track the imports
2. Here to actually execute the module again with the respective
imported modules reloaded.
"""
if file_name is None:
return None
module_name = file_name.replace(os.sep, ".").replace(".py", "")
spec = importlib.util.spec_from_file_location(module_name, file_name)
module = importlib.util.module_from_spec(spec)
if is_during_reload:
imported_modules = ModuleLoader._exec_module_and_track_imports(spec, module)
reloaded_modules_tracker = set()
ModuleLoader._reload_modules(imported_modules, reloaded_modules_tracker)
spec.loader.exec_module(module)
return module
@staticmethod
def _exec_module_and_track_imports(spec, module: Module) -> set[str]:
"""
Executes the given module (imports it) and returns all the modules that
are imported during its execution.
This is achieved by replacing the __import__ function with a custom one
that tracks the imported modules. At the end, the original __import__
built-in function is restored.
"""
imported_modules: set[str] = set()
original_import = builtins.__import__
def tracked_import(name, globals=None, locals=None, fromlist=(), level=0):
"""
Custom __import__ function that does exactly the same as the original
one, but also tracks the imported modules by means of adding their
names to a set.
"""
result = original_import(name, globals, locals, fromlist, level)
imported_modules.add(name)
return result
builtins.__import__ = tracked_import
try:
module_name = module.__name__
log.debug('Reloading module "%s"', module_name)
spec.loader.exec_module(module)
finally:
builtins.__import__ = original_import
return imported_modules
@staticmethod
def _reload_modules(modules: set[str], reloaded_modules_tracker: set[str]):
"""
Out of the given modules, reloads the ones that were not already imported.
We skip modules that are not user-defined (see `is_user_defined_module()`).
"""
for mod in modules:
if mod in reloaded_modules_tracker:
continue
if not ModuleLoader._is_user_defined_module(mod):
continue
module = sys.modules[mod]
ModuleLoader._deep_reload(module, reloaded_modules_tracker)
reloaded_modules_tracker.add(mod)
@staticmethod
def _is_user_defined_module(mod: str) -> bool:
"""
Returns whether the given module is user-defined or not.
A module is considered user-defined if
- it is not part of the standard library
- AND it is not an external library (site-packages or dist-packages)
"""
if mod not in sys.modules:
return False
if mod in sys.builtin_module_names:
return False
module = sys.modules[mod]
module_path = getattr(module, "__file__", None)
if module_path is None:
return False
module_path = os.path.abspath(module_path)
# External libraries (site-packages or dist-packages), e.g. numpy
if "site-packages" in module_path or "dist-packages" in module_path:
return False
# Standard lib
standard_lib_path = sysconfig.get_path("stdlib")
if module_path.startswith(standard_lib_path):
return False
return True
@staticmethod
def _deep_reload(module: Module, reloaded_modules_tracker: set[str]):
"""
Recursively reloads modules imported by the given module.
Only user-defined modules are reloaded, see `is_user_defined_module()`.
"""
ignore_manimlib_modules = manim_config.ignore_manimlib_modules_on_reload
if ignore_manimlib_modules and module.__name__.startswith("manimlib"):
return
if module.__name__.startswith("manimlib.config"):
# We don't want to reload global manim_config
return
if not hasattr(module, "__dict__"):
return
# Prevent reloading the same module multiple times
if module.__name__ in reloaded_modules_tracker:
return
reloaded_modules_tracker.add(module.__name__)
# Recurse for all imported modules
for _attr_name, attr_value in module.__dict__.items():
if isinstance(attr_value, Module):
if ModuleLoader._is_user_defined_module(attr_value.__name__):
ModuleLoader._deep_reload(attr_value, reloaded_modules_tracker)
# Also reload modules that are part of a class or function
# e.g. when importing `from custom_module import CustomClass`
elif hasattr(attr_value, "__module__"):
attr_module_name = attr_value.__module__
if ModuleLoader._is_user_defined_module(attr_module_name):
attr_module = sys.modules[attr_module_name]
ModuleLoader._deep_reload(attr_module, reloaded_modules_tracker)
# Reload
log.debug('Reloading module "%s"', module.__name__)
importlib.reload(module)