Refactor config.py

This commit is contained in:
Grant Sanderson 2022-12-20 09:31:02 -08:00
parent a26fe605b3
commit 9f71f87278

View file

@ -1,6 +1,8 @@
from __future__ import annotations
import argparse import argparse
from argparse import Namespace
import colour import colour
from contextlib import contextmanager
import importlib import importlib
import inspect import inspect
import os import os
@ -13,6 +15,10 @@ from manimlib.utils.dict_ops import merge_dicts_recursively
from manimlib.utils.init_config import init_customization from manimlib.utils.init_config import init_customization
from manimlib.constants import FRAME_HEIGHT from manimlib.constants import FRAME_HEIGHT
from typing import TYPE_CHECKING
if TYPE_CHECKING:
Module = importlib.util.types.ModuleType
__config_file__ = "custom_config.yml" __config_file__ = "custom_config.yml"
@ -69,7 +75,7 @@ def parse_cli():
parser.add_argument( parser.add_argument(
"-p", "--presenter_mode", "-p", "--presenter_mode",
action="store_true", action="store_true",
help="Scene will stay paused during wait calls until " help="Scene will stay paused during wait calls until " + \
"space bar or right arrow is hit, like a slide show" "space bar or right arrow is hit, like a slide show"
) )
parser.add_argument( parser.add_argument(
@ -118,18 +124,18 @@ def parse_cli():
) )
parser.add_argument( parser.add_argument(
"-n", "--start_at_animation_number", "-n", "--start_at_animation_number",
help="Start rendering not from the first animation, but " help="Start rendering not from the first animation, but " + \
"from another, specified by its index. If you pass " "from another, specified by its index. If you pass " + \
"in two comma separated values, e.g. \"3,6\", it will end " "in two comma separated values, e.g. \"3,6\", it will end " + \
"the rendering at the second value", "the rendering at the second value",
) )
parser.add_argument( parser.add_argument(
"-e", "--embed", "-e", "--embed",
nargs="?", nargs="?",
const="", const="",
help="Creates a new file where the line `self.embed` is inserted " help="Creates a new file where the line `self.embed` is inserted " + \
"into the Scenes construct method. " "into the Scenes construct method. " + \
"If a string is passed in, the line will be inserted below the " "If a string is passed in, the line will be inserted below the " + \
"last line of code including that string." "last line of code including that string."
) )
parser.add_argument( parser.add_argument(
@ -172,6 +178,7 @@ def parse_cli():
help="Level of messages to Display, can be DEBUG / INFO / WARNING / ERROR / CRITICAL" help="Level of messages to Display, can be DEBUG / INFO / WARNING / ERROR / CRITICAL"
) )
args = parser.parse_args() args = parser.parse_args()
args.write_file = any([args.write_file, args.open, args.finder])
return args return args
except argparse.ArgumentError as err: except argparse.ArgumentError as err:
log.error(str(err)) log.error(str(err))
@ -184,7 +191,7 @@ def get_manim_dir():
return os.path.abspath(os.path.join(manimlib_dir, "..")) return os.path.abspath(os.path.join(manimlib_dir, ".."))
def get_module(file_name: str | None): def get_module(file_name: str | None) -> Module:
if file_name is None: if file_name is None:
return None return None
module_name = file_name.replace(os.sep, ".").replace(".py", "") module_name = file_name.replace(os.sep, ".").replace(".py", "")
@ -271,6 +278,15 @@ def get_module_with_inserted_embed_line(
return module return module
def get_scene_module(args: Namespace) -> Module:
if args.embed is None:
return get_module(args.file)
else:
return get_module_with_inserted_embed_line(
args.file, args.scene_names[0], args.embed
)
def get_custom_config(): def get_custom_config():
global __config_file__ global __config_file__
@ -278,50 +294,49 @@ def get_custom_config():
if os.path.exists(global_defaults_file): if os.path.exists(global_defaults_file):
with open(global_defaults_file, "r") as file: with open(global_defaults_file, "r") as file:
config = yaml.safe_load(file) custom_config = yaml.safe_load(file)
if os.path.exists(__config_file__): if os.path.exists(__config_file__):
with open(__config_file__, "r") as file: with open(__config_file__, "r") as file:
local_defaults = yaml.safe_load(file) local_defaults = yaml.safe_load(file)
if local_defaults: if local_defaults:
config = merge_dicts_recursively( custom_config = merge_dicts_recursively(
config, custom_config,
local_defaults, local_defaults,
) )
else: else:
with open(__config_file__, "r") as file: with open(__config_file__, "r") as file:
config = yaml.safe_load(file) custom_config = yaml.safe_load(file)
return config # Check temporary storage(custom_config)
if custom_config["directories"]["temporary_storage"] == "" and sys.platform == "win32":
def check_temporary_storage(config):
if config["directories"]["temporary_storage"] == "" and sys.platform == "win32":
log.warning( log.warning(
"You may be using Windows platform and have not specified the path of" "You may be using Windows platform and have not specified the path of" + \
" `temporary_storage`, which may cause OSError. So it is recommended" " `temporary_storage`, which may cause OSError. So it is recommended" + \
" to specify the `temporary_storage` in the config file (.yml)" " to specify the `temporary_storage` in the config file (.yml)"
) )
return custom_config
def get_configuration(args):
def init_global_config(config_file):
global __config_file__ global __config_file__
# ensure __config_file__ always exists # ensure __config_file__ always exists
if args.config_file is not None: if config_file is not None:
if not os.path.exists(args.config_file): if not os.path.exists(config_file):
log.error(f"Can't find {args.config_file}.") log.error(f"Can't find {config_file}.")
if sys.platform == 'win32': if sys.platform == 'win32':
log.info(f"Copying default configuration file to {args.config_file}...") log.info(f"Copying default configuration file to {config_file}...")
os.system(f"copy default_config.yml {args.config_file}") os.system(f"copy default_config.yml {config_file}")
elif sys.platform in ["linux2", "darwin"]: elif sys.platform in ["linux2", "darwin"]:
log.info(f"Copying default configuration file to {args.config_file}...") log.info(f"Copying default configuration file to {config_file}...")
os.system(f"cp default_config.yml {args.config_file}") os.system(f"cp default_config.yml {config_file}")
else: else:
log.info("Please create the configuration file manually.") log.info("Please create the configuration file manually.")
log.info("Read configuration from default_config.yml.") log.info("Read configuration from default_config.yml.")
else: else:
__config_file__ = args.config_file __config_file__ = config_file
global_defaults_file = os.path.join(get_manim_dir(), "manimlib", "default_config.yml") global_defaults_file = os.path.join(get_manim_dir(), "manimlib", "default_config.yml")
@ -336,17 +351,28 @@ def get_configuration(args):
f" `{__config_file__}`, or run `manimgl --config`" f" `{__config_file__}`, or run `manimgl --config`"
) )
custom_config = get_custom_config()
check_temporary_storage(custom_config)
write_file = any([args.write_file, args.open, args.finder]) def get_file_ext(args: Namespace) -> str:
if args.transparent: if args.transparent:
file_ext = ".mov" file_ext = ".mov"
elif args.gif: elif args.gif:
file_ext = ".gif" file_ext = ".gif"
else: else:
file_ext = ".mp4" file_ext = ".mp4"
return file_ext
def get_animations_numbers(args: Namespace) -> tuple[int | None, int | None]:
stan = args.start_at_animation_number
if stan is None:
return (None, None)
elif "," in stan:
return tuple(map(int, stan.split(",")))
else:
return int(stan), None
def get_output_directory(args: Namespace, custom_config: dict) -> str:
dir_config = custom_config["directories"] dir_config = custom_config["directories"]
output_directory = args.video_dir or dir_config["output"] output_directory = args.video_dir or dir_config["output"]
if dir_config["mirror_module_path"] and args.file: if dir_config["mirror_module_path"] and args.file:
@ -356,16 +382,19 @@ def get_configuration(args):
if ext.startswith("_"): if ext.startswith("_"):
ext = ext[1:] ext = ext[1:]
output_directory = os.path.join(output_directory, ext) output_directory = os.path.join(output_directory, ext)
return output_directory
file_writer_config = {
"write_to_movie": not args.skip_animations and write_file, def get_file_writer_config(args: Namespace, custom_config: dict) -> dict:
return {
"write_to_movie": not args.skip_animations and args.write_file,
"break_into_partial_movies": custom_config["break_into_partial_movies"], "break_into_partial_movies": custom_config["break_into_partial_movies"],
"save_last_frame": args.skip_animations and write_file, "save_last_frame": args.skip_animations and args.write_file,
"save_pngs": args.save_pngs, "save_pngs": args.save_pngs,
# If -t is passed in (for transparent), this will be RGBA # If -t is passed in (for transparent), this will be RGBA
"png_mode": "RGBA" if args.transparent else "RGB", "png_mode": "RGBA" if args.transparent else "RGB",
"movie_file_extension": file_ext, "movie_file_extension": get_file_ext(args),
"output_directory": output_directory, "output_directory": get_output_directory(args, custom_config),
"file_name": args.file_name, "file_name": args.file_name,
"input_file_path": args.file or "", "input_file_path": args.file or "",
"open_file_upon_completion": args.open, "open_file_upon_completion": args.open,
@ -373,61 +402,26 @@ def get_configuration(args):
"quiet": args.quiet, "quiet": args.quiet,
} }
if args.embed is not None:
module = get_module_with_inserted_embed_line(
args.file, args.scene_names[0], args.embed
)
else:
module = get_module(args.file)
config = {
"module": module,
"scene_names": args.scene_names,
"file_writer_config": file_writer_config,
"quiet": args.quiet or args.write_all,
"write_all": args.write_all,
"skip_animations": args.skip_animations,
"start_at_animation_number": args.start_at_animation_number,
"end_at_animation_number": None,
"preview": not write_file,
"presenter_mode": args.presenter_mode,
"leave_progress_bars": args.leave_progress_bars,
"show_animation_progress": args.show_animation_progress,
}
# Camera configuration
config["camera_config"] = get_camera_configuration(args, custom_config)
def get_window_config(args: Namespace, custom_config: dict, camera_config: dict) -> dict:
# Default to making window half the screen size # Default to making window half the screen size
# but make it full screen if -f is passed in # but make it full screen if -f is passed in
monitors = get_monitors() monitors = get_monitors()
mon_index = custom_config["window_monitor"] mon_index = custom_config["window_monitor"]
monitor = monitors[min(mon_index, len(monitors) - 1)] monitor = monitors[min(mon_index, len(monitors) - 1)]
aspect_ratio = config["camera_config"]["pixel_width"] / config["camera_config"]["pixel_height"] aspect_ratio = camera_config["pixel_width"] / camera_config["pixel_height"]
window_width = monitor.width window_width = monitor.width
if not (args.full_screen or custom_config["full_screen"]): if not (args.full_screen or custom_config["full_screen"]):
window_width //= 2 window_width //= 2
window_height = int(window_width / aspect_ratio) window_height = int(window_width / aspect_ratio)
config["window_config"] = { return {
"size": (window_width, window_height), "size": (window_width, window_height),
} }
# Arguments related to skipping
stan = config["start_at_animation_number"]
if stan is not None:
if "," in stan:
start, end = stan.split(",")
config["start_at_animation_number"] = int(start)
config["end_at_animation_number"] = int(end)
else:
config["start_at_animation_number"] = int(stan)
return config def get_camera_configuration(args: Namespace, custom_config: dict) -> dict:
def get_camera_configuration(args, custom_config):
camera_config = {} camera_config = {}
camera_resolutions = get_custom_config()["camera_resolutions"] camera_resolutions = custom_config["camera_resolutions"]
if args.resolution: if args.resolution:
resolution = args.resolution resolution = args.resolution
elif args.low_quality: elif args.low_quality:
@ -444,7 +438,7 @@ def get_camera_configuration(args, custom_config):
if args.fps: if args.fps:
fps = int(args.fps) fps = int(args.fps)
else: else:
fps = get_custom_config()["fps"] fps = custom_config["fps"]
width_str, height_str = resolution.split("x") width_str, height_str = resolution.split("x")
width = int(width_str) width = int(width_str)
@ -473,3 +467,28 @@ def get_camera_configuration(args, custom_config):
camera_config["background_opacity"] = 0 camera_config["background_opacity"] = 0
return camera_config return camera_config
def get_configuration(args: Namespace) -> dict:
init_global_config(args.config_file)
custom_config = get_custom_config()
camera_config = get_camera_configuration(args, custom_config)
window_config = get_window_config(args, custom_config, camera_config)
start, end = get_animations_numbers(args)
return {
"module": get_scene_module(args),
"scene_names": args.scene_names,
"file_writer_config": get_file_writer_config(args, custom_config),
"camera_config": camera_config,
"window_config": window_config,
"quiet": args.quiet or args.write_all,
"write_all": args.write_all,
"skip_animations": args.skip_animations,
"start_at_animation_number": start,
"end_at_animation_number": end,
"preview": not args.write_file,
"presenter_mode": args.presenter_mode,
"leave_progress_bars": args.leave_progress_bars,
"show_animation_progress": args.show_animation_progress,
}