Dev/external dev #87

Merged
merged 4 commits, Oct 15, 2023
3 files renamed without changes.
4 changes: 3 additions & 1 deletion cares_reinforcement_learning/util/NetworkFactory.py
@@ -1,4 +1,5 @@
import torch
import logging

def create_DQN(args):
from cares_reinforcement_learning.algorithm.value import DQN
@@ -159,4 +160,5 @@ def create_network(self, algorithm, args):
return create_SAC(args)
elif algorithm == "TD3":
return create_TD3(args)
raise ValueError(f"Unkown algorithm: {algorithm}")
logging.warning(f"Algorithm: {algorithm} is not in the default cares_rl factory")
return None
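
With create_network now returning None for algorithms it does not recognise, instead of raising ValueError, an external project can fall back to its own constructor. A minimal sketch of that pattern, assuming the factory class is named NetworkFactory (inferred from the file name); create_external_agent is a hypothetical project-specific fallback:

from cares_reinforcement_learning.util.NetworkFactory import NetworkFactory

def create_external_agent(algorithm, args):
    # Hypothetical placeholder for a project-specific constructor covering
    # algorithms that the default cares_rl factory does not provide.
    raise NotImplementedError(f"No external constructor for {algorithm}")

def create_agent(algorithm, args):
    factory = NetworkFactory()  # class name assumed from util/NetworkFactory.py
    # The factory now returns None (and logs a warning) for unknown algorithms.
    agent = factory.create_network(algorithm, args)
    if agent is None:
        agent = create_external_agent(algorithm, args)
    return agent
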
18 changes: 9 additions & 9 deletions cares_reinforcement_learning/util/arguement_parser.py
@@ -73,9 +73,9 @@ def algorithm_args(parent_parser):
parser_DoubleDQN.add_argument('--exploration_min', type=float, default=1e-3)
parser_DoubleDQN.add_argument('--exploration_decay', type=float, default=0.95)

return alg_parser
return alg_parser, alg_parsers

def parse_args():
def environment_parser():
parser = argparse.ArgumentParser(add_help=False) # Add an argument

parser.add_argument('--number_training_iterations', type=int, default=1, help="Total amount of training iterations to complete")
@@ -86,7 +86,7 @@ def parse_args():
parser.add_argument('--G', type=int, default=10, help="Number of learning updates each step of training")
parser.add_argument('--batch_size', type=int, default=32, help="Batch Size used during training")

parser.add_argument('--max_steps_exploration', type=int, default=10000, help="Total number of steps for exploration before training")
parser.add_argument('--max_steps_exploration', type=int, default=1000, help="Total number of steps for exploration before training")
parser.add_argument('--max_steps_training', type=int, default=100000, help="Total number of steps to train the algorithm")

parser.add_argument('--number_steps_per_evaluation', type=int, default=10000, help="The number of steps in between evaluation runs during training")
@@ -97,10 +97,10 @@ def parse_args():
parser.add_argument('--plot_frequency', type=int, default=100, help="How many steps between updating the running plot of the training and evaluation data during training")
parser.add_argument('--checkpoint_frequency', type=int, default=100, help="How many steps between saving check point models of the agent during training")

parser = algorithm_args(parent_parser=parser)
return parser

def create_parser():
parser = environment_parser()
parser, alg_parsers = algorithm_args(parent_parser=parser)
parser = environment_args(parent_parser=parser)

return vars(parser.parse_args()) # converts to a dictionary

if __name__ == '__main__':
print(parse_args())
return parser
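
parse_args has been split: environment_parser builds the base arguments, and create_parser now returns the assembled parser instead of a parsed dictionary, leaving the parse_args call to the caller. A minimal sketch of the new call pattern; the --my_custom_flag argument is purely hypothetical and only illustrates that a downstream script can extend the parser before parsing:

from cares_reinforcement_learning.util import arguement_parser as ap

# create_parser() now returns an argparse.ArgumentParser rather than a dict.
parser = ap.create_parser()

# A downstream script may attach its own optional arguments before parsing.
parser.add_argument('--my_custom_flag', type=int, default=0)  # hypothetical

args = vars(parser.parse_args())  # converts the parsed Namespace to a dictionary
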
11 changes: 7 additions & 4 deletions example/example_training_loops.py
@@ -9,9 +9,9 @@
from cares_reinforcement_learning.util import EnvironmentFactory
from cares_reinforcement_learning.util import arguement_parser as ap

import example.policy_example as pbe
import example.value_example as vbe
import ppo_example as ppe
import cares_reinforcement_learning.train_loops.policy_loop as pbe
import cares_reinforcement_learning.train_loops.value_loop as vbe
import cares_reinforcement_learning.train_loops.ppo_loop as ppe

import gym
from gym import spaces
@@ -28,7 +28,8 @@ def set_seed(seed):
random.seed(seed)

def main():
args = ap.parse_args()
parser = ap.create_parser()
args = vars(parser.parse_args()) # converts to a dictionary

args["device"] = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logging.info(f"Device: {args['device']}")
@@ -57,6 +58,8 @@ def main():

logging.info(f"Algorithm: {args['algorithm']}")
agent = network_factory.create_network(args["algorithm"], args)
if agent is None:
raise ValueError(f"Unknown agent for default algorithms: {args['algorithm']}")

memory = memory_factory.create_memory(args['memory'], args)
logging.info(f"Memory: {args['memory']}")
6 changes: 3 additions & 3 deletions tests/test_utils.py
@@ -72,9 +72,9 @@ def test_create_network():

agent = factory.create_network("TD3", args)
assert isinstance(agent, TD3), "Failed to create TD3 agent"

with pytest.raises(ValueError):
factory.create_network("Unknown", args)
agent = factory.create_network("Unknown", args)
assert agent is None, f"Unknown failed to return None: returned {agent}"


def test_denormalize():
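
Because the factory now logs a warning rather than raising, that warning could also be checked directly with pytest's built-in caplog fixture. A possible companion test, not part of this PR; the NetworkFactory import and no-argument constructor are assumptions based on the file shown above:

from cares_reinforcement_learning.util.NetworkFactory import NetworkFactory

def test_create_network_unknown_logs_warning(caplog):
    factory = NetworkFactory()  # constructor signature assumed
    # An unrecognised algorithm never reaches the per-algorithm constructors,
    # so an empty args dictionary is sufficient for this check.
    agent = factory.create_network("Unknown", {})
    assert agent is None
    assert any("Unknown" in record.getMessage() for record in caplog.records)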