diff --git a/cares_reinforcement_learning/util/NetworkFactory.py b/cares_reinforcement_learning/util/NetworkFactory.py index d20f8de7..791ef868 100644 --- a/cares_reinforcement_learning/util/NetworkFactory.py +++ b/cares_reinforcement_learning/util/NetworkFactory.py @@ -1,7 +1,7 @@ import torch import logging - from cares_reinforcement_learning.util.configurations import AlgorithmConfig +import sys, inspect def create_DQN(observation_size, action_num, config: AlgorithmConfig): from cares_reinforcement_learning.algorithm.value import DQN @@ -35,7 +35,7 @@ def create_DuelingDQN(observation_size, action_num, config: AlgorithmConfig): return agent -def create_DDQN(observation_size, action_num, config: AlgorithmConfig): +def create_DoubleDQN(observation_size, action_num, config: AlgorithmConfig): from cares_reinforcement_learning.algorithm.value import DoubleDQN from cares_reinforcement_learning.networks.DoubleDQN import Network @@ -170,21 +170,20 @@ def create_NaSATD3(observation_size, action_num, config: AlgorithmConfig): class NetworkFactory: def create_network(self, observation_size, action_num, config: AlgorithmConfig): algorithm = config.algorithm - if algorithm == "DQN": - return create_DQN(observation_size, action_num, config) - elif algorithm == "DoubleDQN": - return create_DDQN(observation_size, action_num, config) - elif algorithm == "DuelingDQN": - return create_DuelingDQN(observation_size, action_num, config) - elif algorithm == "PPO": - return create_PPO(observation_size, action_num, config) - elif algorithm == "DDPG": - return create_DDPG(observation_size, action_num, config) - elif algorithm == "SAC": - return create_SAC(observation_size, action_num, config) - elif algorithm == "TD3": - return create_TD3(observation_size, action_num, config) - elif algorithm == "NaSATD3": - return create_NaSATD3(observation_size, action_num, config) - logging.warn(f"Algorithm: {algorithm} is not in the default cares_rl factory") - return None + + ''' + Method taken from: + https://stackoverflow.com/questions/1796180/how-can-i-get-a-list-of-all-classes-within-current-module-in-python + ''' + + agent = None + + for name, obj in inspect.getmembers(sys.modules[__name__]): + if inspect.isfunction(obj): + if name == f"create_{algorithm}": + agent = obj(observation_size, action_num, config) + + if agent is None: + logging.warn(f"Unkown failed to return None: returned {agent}") + + return agent diff --git a/tests/test_utils.py b/tests/test_utils.py index af6a9cbb..069857dd 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -13,7 +13,7 @@ def test_create_agents(): agent = create_DuelingDQN(10, 5, DuelingDQNConfig()) assert isinstance(agent, DQN), "Failed to create DuelingDQN agent" - agent = create_DDQN(10, 5, DoubleDQNConfig()) + agent = create_DoubleDQN(10, 5, DoubleDQNConfig()) assert isinstance(agent, DoubleDQN), "Failed to create DDQN agent" agent = create_PPO(10, 5,PPOConfig())