Skip to content

Commit

Permalink
Merge pull request #108 from UoA-CARES/dev/dynamic-network-factory
Browse files Browse the repository at this point in the history
feat: make network factory dynamic
  • Loading branch information
beardyFace authored Nov 1, 2023
2 parents eb94ebf + 7a3fe8a commit b52f714
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 21 deletions.
39 changes: 19 additions & 20 deletions cares_reinforcement_learning/util/NetworkFactory.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch
import logging

from cares_reinforcement_learning.util.configurations import AlgorithmConfig
import sys, inspect

def create_DQN(observation_size, action_num, config: AlgorithmConfig):
from cares_reinforcement_learning.algorithm.value import DQN
Expand Down Expand Up @@ -35,7 +35,7 @@ def create_DuelingDQN(observation_size, action_num, config: AlgorithmConfig):
return agent


def create_DDQN(observation_size, action_num, config: AlgorithmConfig):
def create_DoubleDQN(observation_size, action_num, config: AlgorithmConfig):
from cares_reinforcement_learning.algorithm.value import DoubleDQN
from cares_reinforcement_learning.networks.DoubleDQN import Network

Expand Down Expand Up @@ -170,21 +170,20 @@ def create_NaSATD3(observation_size, action_num, config: AlgorithmConfig):
class NetworkFactory:
def create_network(self, observation_size, action_num, config: AlgorithmConfig):
algorithm = config.algorithm
if algorithm == "DQN":
return create_DQN(observation_size, action_num, config)
elif algorithm == "DoubleDQN":
return create_DDQN(observation_size, action_num, config)
elif algorithm == "DuelingDQN":
return create_DuelingDQN(observation_size, action_num, config)
elif algorithm == "PPO":
return create_PPO(observation_size, action_num, config)
elif algorithm == "DDPG":
return create_DDPG(observation_size, action_num, config)
elif algorithm == "SAC":
return create_SAC(observation_size, action_num, config)
elif algorithm == "TD3":
return create_TD3(observation_size, action_num, config)
elif algorithm == "NaSATD3":
return create_NaSATD3(observation_size, action_num, config)
logging.warn(f"Algorithm: {algorithm} is not in the default cares_rl factory")
return None

'''
Method taken from:
https://stackoverflow.com/questions/1796180/how-can-i-get-a-list-of-all-classes-within-current-module-in-python
'''

agent = None

for name, obj in inspect.getmembers(sys.modules[__name__]):
if inspect.isfunction(obj):
if name == f"create_{algorithm}":
agent = obj(observation_size, action_num, config)

if agent is None:
logging.warn(f"Unkown failed to return None: returned {agent}")

return agent
2 changes: 1 addition & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def test_create_agents():
agent = create_DuelingDQN(10, 5, DuelingDQNConfig())
assert isinstance(agent, DQN), "Failed to create DuelingDQN agent"

agent = create_DDQN(10, 5, DoubleDQNConfig())
agent = create_DoubleDQN(10, 5, DoubleDQNConfig())
assert isinstance(agent, DoubleDQN), "Failed to create DDQN agent"

agent = create_PPO(10, 5,PPOConfig())
Expand Down

0 comments on commit b52f714

Please sign in to comment.