From fc2fa5448c8dd72caaedfd925d25608ad4346b02 Mon Sep 17 00:00:00 2001 From: Mike Gimelfarb <35513382+mike-gimelfarb@users.noreply.github.com> Date: Tue, 7 Jan 2025 09:39:24 -0500 Subject: [PATCH] Allow passing custom config to run script --- README.md | 9 +++++---- pyRDDLGym_jax/examples/run_plan.py | 24 ++++++++++++++---------- 2 files changed, 19 insertions(+), 14 deletions(-) diff --git a/README.md b/README.md index dd2271b..b82af3b 100644 --- a/README.md +++ b/README.md @@ -73,15 +73,16 @@ jaxplan plan <domain> <instance> <method> <episodes> where: - ``domain`` is the domain identifier as specified in rddlrepository (i.e. Wildfire_MDP_ippc2014), or a path pointing to a valid ``domain.rddl`` file - ``instance`` is the instance identifier (i.e. 1, 2, ... 10), or a path pointing to a valid ``instance.rddl`` file -- ``method`` is the planning method to use (i.e. drp, slp, replan) +- ``method`` is the planning method to use (i.e. drp, slp, replan) or a path to a valid .cfg file (see section below) - ``episodes`` is the (optional) number of episodes to evaluate the learned policy. -The ``method`` parameter supports three possible modes: +The ``method`` parameter supports four possible modes: - ``slp`` is the basic straight line planner described [in this paper](https://proceedings.neurips.cc/paper_files/paper/2017/file/98b17f068d5d9b7668e19fb8ae470841-Paper.pdf) - ``drp`` is the deep reactive policy network described [in this paper](https://ojs.aaai.org/index.php/AAAI/article/view/4744) -- ``replan`` is the same as ``slp`` except the plan is recalculated at every decision time step. +- ``replan`` is the same as ``slp`` except the plan is recalculated at every decision time step +- any other argument is interpreted as a file path to a valid configuration file. 
-For example, the following will train JaxPlan on the Quadcopter domain with 4 drones: +For example, the following will train JaxPlan on the Quadcopter domain with 4 drones (with default config): ```shell jaxplan plan Quadcopter 1 slp diff --git a/pyRDDLGym_jax/examples/run_plan.py b/pyRDDLGym_jax/examples/run_plan.py index 8170e5f..b84db32 100644 --- a/pyRDDLGym_jax/examples/run_plan.py +++ b/pyRDDLGym_jax/examples/run_plan.py @@ -12,7 +12,7 @@ where: <domain> is the name of a domain located in the /Examples directory <instance> is the instance number - <method> is either slp, drp, or replan + <method> is slp, drp, replan, or a path to a valid .cfg file <episodes> is the optional number of evaluation rollouts ''' import os @@ -32,12 +32,19 @@ def main(domain, instance, method, episodes=1): env = pyRDDLGym.make(domain, instance, vectorized=True) # load the config file with planner settings - abs_path = os.path.dirname(os.path.abspath(__file__)) - config_path = os.path.join(abs_path, 'configs', f'{domain}_{method}.cfg') - if not os.path.isfile(config_path): - raise_warning(f'Config file {config_path} was not found, ' - f'using default_{method}.cfg.', 'red') - config_path = os.path.join(abs_path, 'configs', f'default_{method}.cfg') + if method in ['drp', 'slp', 'replan']: + abs_path = os.path.dirname(os.path.abspath(__file__)) + config_path = os.path.join(abs_path, 'configs', f'{domain}_{method}.cfg') + if not os.path.isfile(config_path): + raise_warning(f'Config file {config_path} was not found, ' + f'using default_{method}.cfg.', 'red') + config_path = os.path.join(abs_path, 'configs', f'default_{method}.cfg') + elif os.path.isfile(method): + config_path = method + else: + print('method must be slp, drp, replan, or a path to a valid .cfg file.') + exit(1) + planner_args, _, train_args = load_config(config_path) if 'dashboard' in train_args: train_args['dashboard'].launch() @@ -59,9 +66,6 @@ def run_from_args(args): if len(args) < 3: print('python run_plan.py []') exit(1) - if args[2] not in ['drp', 'slp', 
'replan']: - print(' in [drp, slp, replan]') - exit(1) kwargs = {'domain': args[0], 'instance': args[1], 'method': args[2]} if len(args) >= 4: kwargs['episodes'] = int(args[3]) main(**kwargs)