diff --git a/cares_reinforcement_learning/util/configurations.py b/cares_reinforcement_learning/util/configurations.py
index edc098a..1c059a1 100644
--- a/cares_reinforcement_learning/util/configurations.py
+++ b/cares_reinforcement_learning/util/configurations.py
@@ -8,6 +8,10 @@
 
 # pylint disbale-next=unused-import
 # NOTE: If a parameter is a list then don't wrap with Optional leave as implicit optional - list[type] = default
+from pathlib import Path
+
+file_path = Path(__file__)
+network_file_path = f"{file_path.parents[1]}/networks"
 
 
 class SubscriptableClass(BaseModel):
@@ -200,6 +204,9 @@ class SACConfig(AlgorithmConfig):
     hidden_size_actor: list[int] = [256, 256]
     hidden_size_critic: list[int] = [256, 256]
 
+    actor_module: str = f"{network_file_path}/SAC/actor.py"
+    critic_module: str = f"{network_file_path}/SAC/critic.py"
+
 
 class SACAEConfig(AlgorithmConfig):
     algorithm: str = Field("SACAE", Literal=True)
diff --git a/cares_reinforcement_learning/util/network_factory.py b/cares_reinforcement_learning/util/network_factory.py
index a821a05..8e29c19 100644
--- a/cares_reinforcement_learning/util/network_factory.py
+++ b/cares_reinforcement_learning/util/network_factory.py
@@ -4,18 +4,29 @@
 """
 
 import copy
+import importlib.util
 import inspect
 import logging
 import sys
 
-import cares_reinforcement_learning.util.helpers as hlp
 import cares_reinforcement_learning.util.configurations as acf
+import cares_reinforcement_learning.util.helpers as hlp
 
 # Disable these as this is a deliberate use of dynamic imports
 # pylint: disable=import-outside-toplevel
 # pylint: disable=invalid-name
 
 
+def import_from_path(module_name, file_path):
+    # Load a module directly from a source-file path (importlib docs,
+    # "Importing a source file directly"). Guard against a bad path:
+    # spec_from_file_location returns None / a loaderless spec for an
+    # unresolvable file, which would otherwise surface as a cryptic
+    # AttributeError on spec.loader.
+    spec = importlib.util.spec_from_file_location(module_name, file_path)
+    if spec is None or spec.loader is None:
+        raise ImportError(f"Cannot load module '{module_name}' from '{file_path}'")
+    module = importlib.util.module_from_spec(spec)
+    sys.modules[module_name] = module
+    spec.loader.exec_module(module)
+    return module
+
+
 def create_DQN(observation_size, action_num, config: acf.DQNConfig):
     from cares_reinforcement_learning.algorithm.value import DQN
     from cares_reinforcement_learning.networks.DQN import Network
@@ -114,15 +123,27 @@ def create_DynaSAC(observation_size, action_num, config: acf.DynaSACConfig):
 
 def create_SAC(observation_size, action_num, config: acf.SACConfig):
     from cares_reinforcement_learning.algorithm.policy import SAC
-    from cares_reinforcement_learning.networks.SAC import Actor, Critic
 
-    actor = Actor(
+    actor_class = import_from_path(
+        "actor",
+        config.actor_module,
+    ).Actor
+
+    critic_class = import_from_path(
+        "critic",
+        config.critic_module,
+    ).Critic
+
+    actor = actor_class(
         observation_size,
         action_num,
         hidden_size=config.hidden_size_actor,
         log_std_bounds=config.log_std_bounds,
     )
-    critic = Critic(observation_size, action_num, hidden_size=config.hidden_size_critic)
+
+    critic = critic_class(
+        observation_size, action_num, hidden_size=config.hidden_size_critic
+    )
 
     device = hlp.get_device()
     agent = SAC(