Skip to content

Commit

Permalink
Added command line functions to plan and tune
Browse files Browse the repository at this point in the history
- "jaxplan plan" will plan
- "jaxplan tune" will tune
  • Loading branch information
mike-gimelfarb committed Jan 7, 2025
1 parent dcb6aed commit 016bd87
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 10 deletions.
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,10 @@ pip install pyRDDLGym-jax[extra,dashboard]

## Running from the Command Line

A basic run script is provided to run JaxPlan on any domain in ``rddlrepository`` from the install directory of pyRDDLGym-jax:
A basic run script is provided to train JaxPlan on any RDDL problem:

```shell
python -m pyRDDLGym_jax.examples.run_plan <domain> <instance> <method> <episodes>
jaxplan plan <domain> <instance> <method> <episodes>
```

where:
Expand All @@ -84,7 +84,7 @@ The ``method`` parameter supports three possible modes:
For example, the following will train JaxPlan on the Quadcopter domain with 4 drones:

```shell
python -m pyRDDLGym_jax.examples.run_plan Quadcopter 1 slp
jaxplan plan Quadcopter 1 slp
```

## Running from Another Python Application
Expand Down Expand Up @@ -249,7 +249,7 @@ tuning.tune(key=42, log_file='path/to/log.csv')
A basic run script is provided to run the automatic hyper-parameter tuning for the most sensitive parameters of JaxPlan:

```shell
python -m pyRDDLGym_jax.examples.run_tune <domain> <instance> <method> <trials> <iters> <workers>
jaxplan tune <domain> <instance> <method> <trials> <iters> <workers>
```

where:
Expand Down
27 changes: 27 additions & 0 deletions pyRDDLGym_jax/entry_point.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import argparse

from pyRDDLGym_jax.examples import run_plan, run_tune

def main():
parser = argparse.ArgumentParser(description="Command line parser for the JaxPlan planner.")
subparsers = parser.add_subparsers(dest="jaxplan", required=True)

# planning
parser_plan = subparsers.add_parser("plan", help="Executes JaxPlan on a specified RDDL problem and method (slp, drp, or replan).")
parser_plan.add_argument('args', nargs=argparse.REMAINDER)

# tuning
parser_tune = subparsers.add_parser("tune", help="Tunes JaxPlan on a specified RDDL problem and method (slp, drp, or replan).")
parser_tune.add_argument('args', nargs=argparse.REMAINDER)

# dispatch
args = parser.parse_args()
if args.jaxplan == "plan":
run_plan.run_from_args(args.args)
elif args.jaxplan == "tune":
run_tune.run_from_args(args.args)
else:
parser.print_help()

if __name__ == "__main__":
main()
9 changes: 6 additions & 3 deletions pyRDDLGym_jax/examples/run_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,8 @@ def main(domain, instance, method, episodes=1):
controller.evaluate(env, episodes=episodes, verbose=True, render=True)
env.close()


if __name__ == "__main__":
args = sys.argv[1:]

def run_from_args(args):
if len(args) < 3:
print('python run_plan.py <domain> <instance> <method> [<episodes>]')
exit(1)
Expand All @@ -66,4 +65,8 @@ def main(domain, instance, method, episodes=1):
kwargs = {'domain': args[0], 'instance': args[1], 'method': args[2]}
if len(args) >= 4: kwargs['episodes'] = int(args[3])
main(**kwargs)


if __name__ == "__main__":
run_from_args(sys.argv[1:])

8 changes: 5 additions & 3 deletions pyRDDLGym_jax/examples/run_tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,7 @@ def main(domain, instance, method, trials=5, iters=20, workers=4):
env.close()


if __name__ == "__main__":
args = sys.argv[1:]
def run_from_args(args):
if len(args) < 3:
print('python run_tune.py <domain> <instance> <method> [<trials>] [<iters>] [<workers>]')
exit(1)
Expand All @@ -88,4 +87,7 @@ def main(domain, instance, method, trials=5, iters=20, workers=4):
if len(args) >= 5: kwargs['iters'] = int(args[4])
if len(args) >= 6: kwargs['workers'] = int(args[5])
main(**kwargs)



if __name__ == "__main__":
run_from_args(sys.argv[1:])
3 changes: 3 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@
python_requires=">=3.9",
package_data={'': ['*.cfg', '*.ico']},
include_package_data=True,
entry_points={
'console_scripts': [ 'jaxplan=pyRDDLGym_jax.entry_point:main'],
},
classifiers=[
"Development Status :: 5 - Production/Stable",
"Intended Audience :: Science/Research",
Expand Down

0 comments on commit 016bd87

Please sign in to comment.