Merge pull request #103 from UoA-CARES/dev/configurations-generic
Added noise decay into policy loop
dvalenciar authored Oct 26, 2023
2 parents 8289b50 + 01a5bfc commit 266ac95
Showing 4 changed files with 18 additions and 11 deletions.
cares_reinforcement_learning/algorithm/policy/DDPG.py (2 changes: 1 addition & 1 deletion)
@@ -33,7 +33,7 @@ def __init__(self,
 
         self.device = device
 
-    def select_action_from_policy(self, state, evaluation=None):
+    def select_action_from_policy(self, state, evaluation=None, noise_scale=0):
        self.actor_net.eval()
        with torch.no_grad():
            state_tensor = torch.FloatTensor(state)
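The hunk above is truncated before the method body that consumes the new noise_scale argument, so its exact use is not visible here. In DDPG/TD3-style agents a decaying noise_scale is typically applied as additive Gaussian noise on the deterministic action during training; the sketch below illustrates only that common pattern. The helper name, the unsqueeze/clip details, and the noise formulation are assumptions for illustration, not the repository's implementation.

    import numpy as np
    import torch

    def select_action_with_noise(actor_net, state, evaluation=False, noise_scale=0.0):
        # Illustrative sketch: deterministic action from the actor network,
        # plus optional Gaussian exploration noise scaled by noise_scale.
        actor_net.eval()
        with torch.no_grad():
            state_tensor = torch.FloatTensor(state).unsqueeze(0)
            action = actor_net(state_tensor).squeeze(0).numpy()
        if not evaluation and noise_scale > 0:
            action = action + np.random.normal(0, noise_scale, size=action.shape)
        # keep the result in the algorithm's normalized [-1, 1] range
        return np.clip(action, -1.0, 1.0)
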
cares_reinforcement_learning/algorithm/policy/SAC.py (2 changes: 1 addition & 1 deletion)
@@ -49,7 +49,7 @@ def __init__(self,
         self.log_alpha.requires_grad = True
         self.log_alpha_optimizer = torch.optim.Adam([self.log_alpha], lr=1e-3)
 
-    def select_action_from_policy(self, state, evaluation=False):
+    def select_action_from_policy(self, state, evaluation=False, noise_scale=0):
        # note that when evaluating this algorithm we need to select mu as action so _, _, action = self.actor_net.sample(state_tensor)
        self.actor_net.eval()
        with torch.no_grad():
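As the comment above notes, SAC uses the mean action only at evaluation time and otherwise explores by sampling from its learned policy distribution, so the visible lines give no sign that noise_scale is consumed; it appears to be accepted so the training loop can call every policy algorithm through one signature. A hedged sketch of that idea, assuming the actor's sample() returns (sampled action, log-prob, mean action) as the comment implies:

    import torch

    def sac_style_select_action(actor_net, state, evaluation=False, noise_scale=0):
        # noise_scale is accepted only to match the shared interface; a
        # stochastic policy gets its exploration from sampling instead.
        actor_net.eval()
        with torch.no_grad():
            state_tensor = torch.FloatTensor(state).unsqueeze(0)
            if evaluation:
                _, _, action = actor_net.sample(state_tensor)  # mean action
            else:
                action, _, _ = actor_net.sample(state_tensor)  # sampled action
        return action.squeeze(0).numpy()
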
cares_reinforcement_learning/train_loops/policy_loop.py (9 changes: 8 additions & 1 deletion)
@@ -60,6 +60,10 @@ def policy_based_train(env, agent, memory, record, train_config: TrainingConfig,
     max_steps_exploration = train_config.max_steps_exploration
     number_steps_per_evaluation = train_config.number_steps_per_evaluation
 
+    min_noise = alg_config.min_noise if hasattr(alg_config, "min_noise") else 0
+    noise_decay = alg_config.noise_decay if hasattr(alg_config, "noise_decay") else 0
+    noise_scale = alg_config.noise_scale if hasattr(alg_config, "noise_scale") else 0
+
     logging.info(f"Training {max_steps_training} Exploration {max_steps_exploration} Evaluation {number_steps_per_evaluation}")
 
     batch_size = train_config.batch_size
@@ -84,8 +88,11 @@ def policy_based_train(env, agent, memory, record, train_config: TrainingConfig,
             # algorithm range [-1, 1] - note for DMCS this is redudenant but required for openai
             action = hlp.normalize(action_env, env.max_action_value, env.min_action_value)
         else:
+            noise_scale *= noise_decay
+            noise_scale = max(min_noise, noise_scale)
+
             # algorithm range [-1, 1]
-            action = agent.select_action_from_policy(state)
+            action = agent.select_action_from_policy(state, noise_scale=noise_scale)
             # mapping to env range [e.g. -2 , 2 for pendulum] - note for DMCS this is redudenant but required for openai
             action_env = hlp.denormalize(action, env.max_action_value, env.min_action_value)
 
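The decay in the else branch is multiplicative with a floor: on every post-exploration step the scale is multiplied by noise_decay and then clamped to at least min_noise. A standalone sketch of that schedule with made-up numbers (the 0.1 starting scale, 0.95 decay, and 0.01 floor are illustrative values, not defaults from this repository):

    min_noise = 0.01     # illustrative floor, not a repository default
    noise_decay = 0.95   # illustrative per-step decay factor
    noise_scale = 0.1    # illustrative starting scale

    for step in range(1, 101):
        # same two lines as the training loop above
        noise_scale *= noise_decay
        noise_scale = max(min_noise, noise_scale)
        if step in (1, 10, 50, 100):
            print(f"step {step:3d}: noise_scale = {noise_scale:.4f}")

    # Prints roughly 0.0950, 0.0599, 0.0100, 0.0100: the scale decays
    # exponentially until 0.1 * 0.95**n drops below the floor (around step 45),
    # after which exploration noise holds at min_noise.
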
cares_reinforcement_learning/util/configurations.py (16 changes: 8 additions & 8 deletions)
@@ -69,25 +69,25 @@ class DoubleDQNConfig(AlgorithmConfig):
     exploration_min: Optional[float] = 1e-3
     exploration_decay: Optional[float] = 0.95
 
-class DDPGConfig(AlgorithmConfig):
-    algorithm: str = Field("DDPG", Literal=True)
+class PPOConfig(AlgorithmConfig):
+    algorithm: str = Field("PPO", Literal=True)
     actor_lr: Optional[float] = 1e-4
     critic_lr: Optional[float] = 1e-3
 
     gamma: Optional[float] = 0.99
-    tau: Optional[float] = 0.005
+    max_steps_per_batch: Optional[int] = 5000
 
-    memory: Optional[str] = "MemoryBuffer"
+    memory: str = Field("MemoryBuffer", Literal=True)
 
-class PPOConfig(AlgorithmConfig):
-    algorithm: str = Field("PPO", Literal=True)
+class DDPGConfig(AlgorithmConfig):
+    algorithm: str = Field("DDPG", Literal=True)
     actor_lr: Optional[float] = 1e-4
     critic_lr: Optional[float] = 1e-3
 
     gamma: Optional[float] = 0.99
-    max_steps_per_batch: Optional[int] = 5000
+    tau: Optional[float] = 0.005
 
-    memory: str = Field("MemoryBuffer", Literal=True)
+    memory: Optional[str] = "MemoryBuffer"
 
 class TD3Config(AlgorithmConfig):
     algorithm: str = Field("TD3", Literal=True)
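The training loop shown earlier reads min_noise, noise_decay, and noise_scale off alg_config with hasattr checks, so only algorithm configs that declare those fields opt in to the decay; everything else silently falls back to 0. A minimal sketch of that fallback behaviour using plain dataclasses (the class names and field values here are hypothetical, not the repository's pydantic configs):

    from dataclasses import dataclass

    @dataclass
    class ExampleNoisyConfig:
        # Hypothetical config in the spirit of the classes above; the noise
        # fields are the ones the training loop looks for with hasattr().
        algorithm: str = "TD3"
        noise_scale: float = 0.1
        noise_decay: float = 0.9999
        min_noise: float = 0.0

    @dataclass
    class ExampleQuietConfig:
        # No noise fields declared, so the loop's hasattr() checks all fail
        # and min_noise / noise_decay / noise_scale fall back to 0.
        algorithm: str = "PPO"

    def read_noise_settings(alg_config):
        # Same fallback pattern as in policy_loop.py above.
        min_noise = alg_config.min_noise if hasattr(alg_config, "min_noise") else 0
        noise_decay = alg_config.noise_decay if hasattr(alg_config, "noise_decay") else 0
        noise_scale = alg_config.noise_scale if hasattr(alg_config, "noise_scale") else 0
        return min_noise, noise_decay, noise_scale

    print(read_noise_settings(ExampleNoisyConfig()))   # (0.0, 0.9999, 0.1)
    print(read_noise_settings(ExampleQuietConfig()))   # (0, 0, 0)
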
