mirror of
https://github.com/3b1b/manim.git
synced 2025-04-13 09:47:07 +00:00
176 lines
6.3 KiB
Python
176 lines
6.3 KiB
Python
from __future__ import annotations
|
|
|
|
import builtins
|
|
import importlib
|
|
import os
|
|
import sys
|
|
import sysconfig
|
|
|
|
from manimlib.config import manim_config
|
|
from manimlib.logger import log
|
|
|
|
Module = importlib.util.types.ModuleType
|
|
|
|
|
|
class ModuleLoader:
|
|
"""
|
|
Utility class to load a module from a file and handle its imports.
|
|
|
|
Most parts of this class are only needed for the reload functionality,
|
|
while the `get_module` method is the main entry point to import a module.
|
|
"""
|
|
|
|
@staticmethod
|
|
def get_module(file_name: str | None, is_during_reload=False) -> Module | None:
|
|
"""
|
|
Imports a module from a file and returns it.
|
|
|
|
During reload (when the user calls `reload()` in the IPython shell), we
|
|
also track the imported modules and reload them as well (they would be
|
|
cached otherwise). See the reload_manager where the reload parameter is set.
|
|
|
|
Note that `exec_module()` is called twice when reloading a module:
|
|
1. In exec_module_and_track_imports to track the imports
|
|
2. Here to actually execute the module again with the respective
|
|
imported modules reloaded.
|
|
"""
|
|
if file_name is None:
|
|
return None
|
|
|
|
module_name = file_name.replace(os.sep, ".").replace(".py", "")
|
|
spec = importlib.util.spec_from_file_location(module_name, file_name)
|
|
module = importlib.util.module_from_spec(spec)
|
|
|
|
if is_during_reload:
|
|
imported_modules = ModuleLoader._exec_module_and_track_imports(spec, module)
|
|
reloaded_modules_tracker = set()
|
|
ModuleLoader._reload_modules(imported_modules, reloaded_modules_tracker)
|
|
|
|
spec.loader.exec_module(module)
|
|
return module
|
|
|
|
@staticmethod
|
|
def _exec_module_and_track_imports(spec, module: Module) -> set[str]:
|
|
"""
|
|
Executes the given module (imports it) and returns all the modules that
|
|
are imported during its execution.
|
|
|
|
This is achieved by replacing the __import__ function with a custom one
|
|
that tracks the imported modules. At the end, the original __import__
|
|
built-in function is restored.
|
|
"""
|
|
imported_modules: set[str] = set()
|
|
original_import = builtins.__import__
|
|
|
|
def tracked_import(name, globals=None, locals=None, fromlist=(), level=0):
|
|
"""
|
|
Custom __import__ function that does exactly the same as the original
|
|
one, but also tracks the imported modules by means of adding their
|
|
names to a set.
|
|
"""
|
|
result = original_import(name, globals, locals, fromlist, level)
|
|
imported_modules.add(name)
|
|
return result
|
|
|
|
builtins.__import__ = tracked_import
|
|
|
|
try:
|
|
module_name = module.__name__
|
|
log.debug('Reloading module "%s"', module_name)
|
|
|
|
spec.loader.exec_module(module)
|
|
finally:
|
|
builtins.__import__ = original_import
|
|
|
|
return imported_modules
|
|
|
|
@staticmethod
|
|
def _reload_modules(modules: set[str], reloaded_modules_tracker: set[str]):
|
|
"""
|
|
Out of the given modules, reloads the ones that were not already imported.
|
|
|
|
We skip modules that are not user-defined (see `is_user_defined_module()`).
|
|
"""
|
|
for mod in modules:
|
|
if mod in reloaded_modules_tracker:
|
|
continue
|
|
|
|
if not ModuleLoader._is_user_defined_module(mod):
|
|
continue
|
|
|
|
module = sys.modules[mod]
|
|
ModuleLoader._deep_reload(module, reloaded_modules_tracker)
|
|
|
|
reloaded_modules_tracker.add(mod)
|
|
|
|
@staticmethod
|
|
def _is_user_defined_module(mod: str) -> bool:
|
|
"""
|
|
Returns whether the given module is user-defined or not.
|
|
|
|
A module is considered user-defined if
|
|
- it is not part of the standard library
|
|
- AND it is not an external library (site-packages or dist-packages)
|
|
"""
|
|
if mod not in sys.modules:
|
|
return False
|
|
|
|
if mod in sys.builtin_module_names:
|
|
return False
|
|
|
|
module = sys.modules[mod]
|
|
module_path = getattr(module, "__file__", None)
|
|
if module_path is None:
|
|
return False
|
|
module_path = os.path.abspath(module_path)
|
|
|
|
# External libraries (site-packages or dist-packages), e.g. numpy
|
|
if "site-packages" in module_path or "dist-packages" in module_path:
|
|
return False
|
|
|
|
# Standard lib
|
|
standard_lib_path = sysconfig.get_path("stdlib")
|
|
if module_path.startswith(standard_lib_path):
|
|
return False
|
|
|
|
return True
|
|
|
|
@staticmethod
|
|
def _deep_reload(module: Module, reloaded_modules_tracker: set[str]):
|
|
"""
|
|
Recursively reloads modules imported by the given module.
|
|
|
|
Only user-defined modules are reloaded, see `is_user_defined_module()`.
|
|
"""
|
|
ignore_manimlib_modules = manim_config.ignore_manimlib_modules_on_reload
|
|
if ignore_manimlib_modules and module.__name__.startswith("manimlib"):
|
|
return
|
|
if module.__name__.startswith("manimlib.config"):
|
|
# We don't want to reload global config
|
|
return
|
|
|
|
if not hasattr(module, "__dict__"):
|
|
return
|
|
|
|
# Prevent reloading the same module multiple times
|
|
if module.__name__ in reloaded_modules_tracker:
|
|
return
|
|
reloaded_modules_tracker.add(module.__name__)
|
|
|
|
# Recurse for all imported modules
|
|
for _attr_name, attr_value in module.__dict__.items():
|
|
if isinstance(attr_value, Module):
|
|
if ModuleLoader._is_user_defined_module(attr_value.__name__):
|
|
ModuleLoader._deep_reload(attr_value, reloaded_modules_tracker)
|
|
|
|
# Also reload modules that are part of a class or function
|
|
# e.g. when importing `from custom_module import CustomClass`
|
|
elif hasattr(attr_value, "__module__"):
|
|
attr_module_name = attr_value.__module__
|
|
if ModuleLoader._is_user_defined_module(attr_module_name):
|
|
attr_module = sys.modules[attr_module_name]
|
|
ModuleLoader._deep_reload(attr_module, reloaded_modules_tracker)
|
|
|
|
# Reload
|
|
log.debug('Reloading module "%s"', module.__name__)
|
|
importlib.reload(module)
|