From 97bc85f9463b1e85d5c4fe660edd057884b46806 Mon Sep 17 00:00:00 2001 From: retinfai Date: Wed, 1 Nov 2023 11:59:21 +1300 Subject: [PATCH 1/4] feat: make network factory dynamic --- .../util/NetworkFactory.py | 32 ++++++++----------- 1 file changed, 13 insertions(+), 19 deletions(-) diff --git a/cares_reinforcement_learning/util/NetworkFactory.py b/cares_reinforcement_learning/util/NetworkFactory.py index d20f8de7..de77ee02 100644 --- a/cares_reinforcement_learning/util/NetworkFactory.py +++ b/cares_reinforcement_learning/util/NetworkFactory.py @@ -1,7 +1,7 @@ import torch import logging - from cares_reinforcement_learning.util.configurations import AlgorithmConfig +import sys, inspect def create_DQN(observation_size, action_num, config: AlgorithmConfig): from cares_reinforcement_learning.algorithm.value import DQN @@ -170,21 +170,15 @@ def create_NaSATD3(observation_size, action_num, config: AlgorithmConfig): class NetworkFactory: def create_network(self, observation_size, action_num, config: AlgorithmConfig): algorithm = config.algorithm - if algorithm == "DQN": - return create_DQN(observation_size, action_num, config) - elif algorithm == "DoubleDQN": - return create_DDQN(observation_size, action_num, config) - elif algorithm == "DuelingDQN": - return create_DuelingDQN(observation_size, action_num, config) - elif algorithm == "PPO": - return create_PPO(observation_size, action_num, config) - elif algorithm == "DDPG": - return create_DDPG(observation_size, action_num, config) - elif algorithm == "SAC": - return create_SAC(observation_size, action_num, config) - elif algorithm == "TD3": - return create_TD3(observation_size, action_num, config) - elif algorithm == "NaSATD3": - return create_NaSATD3(observation_size, action_num, config) - logging.warn(f"Algorithm: {algorithm} is not in the default cares_rl factory") - return None + + ''' + Method taken from: + https://stackoverflow.com/questions/1796180/how-can-i-get-a-list-of-all-classes-within-current-module-in-python + ''' + + for name, obj in inspect.getmembers(sys.modules[__name__]): + if inspect.isfunction(obj): + if name == f"create_{algorithm}": + return obj(observation_size, action_num, config) + + raise Exception(f"CARES RL NetworkFactory: {algorithm} is not implemented ") From 7f04a593471e6ac5347586fcad03bc3730b683b1 Mon Sep 17 00:00:00 2001 From: retinfai Date: Wed, 1 Nov 2023 12:26:06 +1300 Subject: [PATCH 2/4] chore: change name to DoubleDQN --- cares_reinforcement_learning/util/NetworkFactory.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cares_reinforcement_learning/util/NetworkFactory.py b/cares_reinforcement_learning/util/NetworkFactory.py index de77ee02..0a803dca 100644 --- a/cares_reinforcement_learning/util/NetworkFactory.py +++ b/cares_reinforcement_learning/util/NetworkFactory.py @@ -35,7 +35,7 @@ def create_DuelingDQN(observation_size, action_num, config: AlgorithmConfig): return agent -def create_DDQN(observation_size, action_num, config: AlgorithmConfig): +def create_DoubleDQN(observation_size, action_num, config: AlgorithmConfig): from cares_reinforcement_learning.algorithm.value import DoubleDQN from cares_reinforcement_learning.networks.DoubleDQN import Network From 322c09b20d297360468d697cd3777c25a5cf40b2 Mon Sep 17 00:00:00 2001 From: retinfai Date: Wed, 1 Nov 2023 12:31:58 +1300 Subject: [PATCH 3/4] test: update tests --- tests/test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index af6a9cbb..069857dd 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -13,7 +13,7 @@ def test_create_agents(): agent = create_DuelingDQN(10, 5, DuelingDQNConfig()) assert isinstance(agent, DQN), "Failed to create DuelingDQN agent" - agent = create_DDQN(10, 5, DoubleDQNConfig()) + agent = create_DoubleDQN(10, 5, DoubleDQNConfig()) assert isinstance(agent, DoubleDQN), "Failed to create DDQN agent" agent = create_PPO(10, 5,PPOConfig()) From 7a3fe8a96ad550bc415fadcaf4914e8984f1b950 Mon Sep 17 00:00:00 2001 From: retinfai Date: Wed, 1 Nov 2023 12:45:08 +1300 Subject: [PATCH 4/4] test: update tests --- cares_reinforcement_learning/util/NetworkFactory.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/cares_reinforcement_learning/util/NetworkFactory.py b/cares_reinforcement_learning/util/NetworkFactory.py index 0a803dca..791ef868 100644 --- a/cares_reinforcement_learning/util/NetworkFactory.py +++ b/cares_reinforcement_learning/util/NetworkFactory.py @@ -176,9 +176,14 @@ def create_network(self, observation_size, action_num, config: AlgorithmConfig): https://stackoverflow.com/questions/1796180/how-can-i-get-a-list-of-all-classes-within-current-module-in-python ''' + agent = None + for name, obj in inspect.getmembers(sys.modules[__name__]): if inspect.isfunction(obj): if name == f"create_{algorithm}": - return obj(observation_size, action_num, config) + agent = obj(observation_size, action_num, config) - raise Exception(f"CARES RL NetworkFactory: {algorithm} is not implemented ") + if agent is None: + logging.warn(f"Unkown failed to return None: returned {agent}") + + return agent