Skip to content

Commit

Permalink
Example of dynamic loading of SAC actor and critic
Browse files Browse the repository at this point in the history
  • Loading branch information
beardyFace committed Nov 20, 2024
1 parent 6e48109 commit fe29735
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 4 deletions.
7 changes: 7 additions & 0 deletions cares_reinforcement_learning/util/configurations.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@
# pylint disable-next=unused-import

# NOTE: If a parameter is a list then don't wrap with Optional leave as implicit optional - list[type] = default
from pathlib import Path

# Path to this file (util/configurations.py); parents[1] is the package root.
file_path = Path(__file__)
# Default location of the bundled network definitions, used as the default
# value for the *_module config fields (e.g. SACConfig.actor_module).
network_file_path = f"{file_path.parents[1]}/networks"


class SubscriptableClass(BaseModel):
Expand Down Expand Up @@ -200,6 +204,9 @@ class SACConfig(AlgorithmConfig):
hidden_size_actor: list[int] = [256, 256]
hidden_size_critic: list[int] = [256, 256]

actor_module: str = f"{network_file_path}/SAC/actor.py"
critic_module: str = f"{network_file_path}/SAC/critic.py"


class SACAEConfig(AlgorithmConfig):
algorithm: str = Field("SACAE", Literal=True)
Expand Down
29 changes: 25 additions & 4 deletions cares_reinforcement_learning/util/network_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,27 @@
"""

import copy
import importlib.util
import inspect
import logging
import sys

import cares_reinforcement_learning.util.helpers as hlp
import cares_reinforcement_learning.util.configurations as acf
import cares_reinforcement_learning.util.helpers as hlp

# Disable these as this is a deliberate use of dynamic imports
# pylint: disable=import-outside-toplevel
# pylint: disable=invalid-name


def import_from_path(module_name: str, file_path: str):
    """Dynamically import a Python module from an arbitrary file path.

    The module is registered in ``sys.modules`` under ``module_name`` so
    later lookups (and any imports performed while the module executes)
    resolve normally.

    Args:
        module_name: Name under which to register the module in sys.modules.
        file_path: Filesystem path to the ``.py`` file to load.

    Returns:
        The fully executed module object.

    Raises:
        ImportError: If no import spec (or loader) could be created for
            ``file_path`` — e.g. the file does not exist or is not a
            loadable Python source file.
    """
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    # spec_from_file_location returns None for paths it cannot handle;
    # fail with a clear ImportError instead of an AttributeError below.
    if spec is None or spec.loader is None:
        raise ImportError(
            f"Could not load module '{module_name}' from '{file_path}'"
        )
    module = importlib.util.module_from_spec(spec)
    # Register before executing so the module is visible in sys.modules
    # while its top-level code runs.
    sys.modules[module_name] = module
    spec.loader.exec_module(module)
    return module


def create_DQN(observation_size, action_num, config: acf.DQNConfig):
from cares_reinforcement_learning.algorithm.value import DQN
from cares_reinforcement_learning.networks.DQN import Network
Expand Down Expand Up @@ -114,15 +123,27 @@ def create_DynaSAC(observation_size, action_num, config: acf.DynaSACConfig):

def create_SAC(observation_size, action_num, config: acf.SACConfig):
from cares_reinforcement_learning.algorithm.policy import SAC
from cares_reinforcement_learning.networks.SAC import Actor, Critic

actor = Actor(
actor_class = import_from_path(
"actor",
config.actor_module,
).Actor

critic_class = import_from_path(
"critic",
config.critic_module,
).Critic

actor = actor_class(
observation_size,
action_num,
hidden_size=config.hidden_size_actor,
log_std_bounds=config.log_std_bounds,
)
critic = Critic(observation_size, action_num, hidden_size=config.hidden_size_critic)

critic = critic_class(
observation_size, action_num, hidden_size=config.hidden_size_critic
)

device = hlp.get_device()
agent = SAC(
Expand Down

0 comments on commit fe29735

Please sign in to comment.