Skip to content

Commit

Permalink
Merge pull request #87 from UoA-CARES/dev/external_dev
Browse files: browse the repository at this point in the history
Dev/external dev
  • Loading branch information
dvalenciar authored Oct 15, 2023
2 parents 548562e + be422cd commit f93f903
Show file tree
Hide file tree
Showing 7 changed files with 22 additions and 17 deletions.
File renamed without changes.
File renamed without changes.
File renamed without changes.
4 changes: 3 additions & 1 deletion cares_reinforcement_learning/util/NetworkFactory.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import torch
import logging

def create_DQN(args):
from cares_reinforcement_learning.algorithm.value import DQN
Expand Down Expand Up @@ -159,4 +160,5 @@ def create_network(self, algorithm, args):
return create_SAC(args)
elif algorithm == "TD3":
return create_TD3(args)
raise ValueError(f"Unkown algorithm: {algorithm}")
logging.warn(f"Algorithm: {algorithm} is not in the default cares_rl factory")
return None
18 changes: 9 additions & 9 deletions cares_reinforcement_learning/util/arguement_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,9 @@ def algorithm_args(parent_parser):
parser_DoubleDQN.add_argument('--exploration_min', type=float, default=1e-3)
parser_DoubleDQN.add_argument('--exploration_decay', type=float, default=0.95)

return alg_parser
return alg_parser, alg_parsers

def parse_args():
def environment_parser():
parser = argparse.ArgumentParser(add_help=False) # Add an argument

parser.add_argument('--number_training_iterations', type=int, default=1, help="Total amount of training iterations to complete")
Expand All @@ -86,7 +86,7 @@ def parse_args():
parser.add_argument('--G', type=int, default=10, help="Number of learning updates each step of training")
parser.add_argument('--batch_size', type=int, default=32, help="Batch Size used during training")

parser.add_argument('--max_steps_exploration', type=int, default=10000, help="Total number of steps for exploration before training")
parser.add_argument('--max_steps_exploration', type=int, default=1000, help="Total number of steps for exploration before training")
parser.add_argument('--max_steps_training', type=int, default=100000, help="Total number of steps to train the algorithm")

parser.add_argument('--number_steps_per_evaluation', type=int, default=10000, help="The number of steps inbetween evaluation runs during training")
Expand All @@ -97,10 +97,10 @@ def parse_args():
parser.add_argument('--plot_frequency', type=int, default=100, help="How many steps between updating the running plot of the training and evaluation data during training")
parser.add_argument('--checkpoint_frequency', type=int, default=100, help="How many steps between saving check point models of the agent during training")

parser = algorithm_args(parent_parser=parser)
return parser

def create_parser():
parser = environment_parser()
parser, alg_parsers = algorithm_args(parent_parser=parser)
parser = environment_args(parent_parser=parser)

return vars(parser.parse_args()) # converts to a dictionary

if __name__ == '__main__':
print(parse_args())
return parser
11 changes: 7 additions & 4 deletions example/example_training_loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
from cares_reinforcement_learning.util import EnvironmentFactory
from cares_reinforcement_learning.util import arguement_parser as ap

import example.policy_example as pbe
import example.value_example as vbe
import ppo_example as ppe
import cares_reinforcement_learning.train_loops.policy_loop as pbe
import cares_reinforcement_learning.train_loops.value_loop as vbe
import cares_reinforcement_learning.train_loops.ppo_loop as ppe

import gym
from gym import spaces
Expand All @@ -28,7 +28,8 @@ def set_seed(seed):
random.seed(seed)

def main():
args = ap.parse_args()
parser = ap.create_parser()
args = vars(parser.parse_args()) # converts to a dictionary

args["device"] = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logging.info(f"Device: {args['device']}")
Expand Down Expand Up @@ -57,6 +58,8 @@ def main():

logging.info(f"Algorithm: {args['algorithm']}")
agent = network_factory.create_network(args["algorithm"], args)
if agent == None:
raise ValueError(f"Unkown agent for default algorithms {args['algorithm']}")

memory = memory_factory.create_memory(args['memory'], args)
logging.info(f"Memory: {args['memory']}")
Expand Down
6 changes: 3 additions & 3 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,9 @@ def test_create_network():

agent = factory.create_network("TD3", args)
assert isinstance(agent, TD3), "Failed to create TD3 agent"

with pytest.raises(ValueError):
factory.create_network("Unknown", args)
agent = factory.create_network("Unknown", args)
assert agent is None, f"Unkown failed to return None: returned {agent}"


def test_denormalize():
Expand Down

0 comments on commit f93f903

Please sign in to comment.