Merge pull request #112 from UoA-CARES/dev/update_sac_to_the_paper
Dev/update sac to the paper
dvalenciar authored Dec 8, 2023
2 parents b52f714 + 1d2d356 commit 8717ce4
Showing 5 changed files with 149 additions and 37 deletions.
3 changes: 1 addition & 2 deletions cares_reinforcement_learning/algorithm/policy/SAC.py
@@ -27,7 +27,6 @@ def __init__(self,
        self.type = "policy"
        self.actor_net = actor_network.to(device)  # this may be called policy_net in other implementations
        self.critic_net = critic_network.to(device)  # this may be called soft_q_net in other implementations
-
        self.target_critic_net = copy.deepcopy(self.critic_net).to(device)

        self.gamma = gamma
@@ -44,7 +43,7 @@ def __init__(self,
        self.actor_net_optimiser = torch.optim.Adam(self.actor_net.parameters(), lr=actor_lr)
        self.critic_net_optimiser = torch.optim.Adam(self.critic_net.parameters(), lr=critic_lr)

-        init_temperature = 0.01
+        init_temperature = 1.0  # Set initial alpha to 1.0, in line with other baselines.
        self.log_alpha = torch.tensor(np.log(init_temperature)).to(device)
        self.log_alpha.requires_grad = True
        self.log_alpha_optimizer = torch.optim.Adam([self.log_alpha], lr=1e-3)
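For context on why init_temperature now feeds a learnable log_alpha with its own optimiser: SAC typically tunes the entropy temperature automatically. The sketch below is not the repository's code, only a minimal illustration of the standard alpha update, assuming a target entropy of -num_actions (num_actions is a hypothetical value here).

    # Minimal sketch of SAC automatic temperature tuning (not the repository's exact code).
    import numpy as np
    import torch

    num_actions = 6                        # hypothetical action dimension
    target_entropy = -num_actions          # common heuristic: -|A|

    init_temperature = 1.0
    log_alpha = torch.tensor(np.log(init_temperature), requires_grad=True)
    log_alpha_optimizer = torch.optim.Adam([log_alpha], lr=1e-3)

    def update_temperature(log_pi: torch.Tensor) -> torch.Tensor:
        """log_pi: log-probabilities of actions sampled from the current policy."""
        alpha_loss = (log_alpha.exp() * (-log_pi - target_entropy).detach()).mean()
        log_alpha_optimizer.zero_grad()
        alpha_loss.backward()
        log_alpha_optimizer.step()
        return log_alpha.exp().detach()    # alpha used in the actor and critic losses

Starting from alpha = 1.0 (log_alpha = 0) simply matches the initial value set in the diff above.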
160 changes: 128 additions & 32 deletions cares_reinforcement_learning/networks/SAC/Actor.py
@@ -1,49 +1,145 @@


import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Normal
from torch import distributions as pyd
import math


# class Actor(nn.Module):
#     def __init__(self, observation_size, num_actions):
#         super(Actor, self).__init__()
#
#         self.hidden_size = [1024, 1024]
#         self.log_sig_min = -20
#         self.log_sig_max = 2
#
#         self.h_linear_1 = nn.Linear(in_features=observation_size, out_features=self.hidden_size[0])
#         self.h_linear_2 = nn.Linear(in_features=self.hidden_size[0], out_features=self.hidden_size[1])
#
#         self.mean_linear = nn.Linear(in_features=self.hidden_size[1], out_features=num_actions)
#         self.log_std_linear = nn.Linear(in_features=self.hidden_size[1], out_features=num_actions)
#
#     def forward(self, state):
#         x = F.relu(self.h_linear_1(state))
#         x = F.relu(self.h_linear_2(x))
#
#         mean = self.mean_linear(x)
#         log_std = self.log_std_linear(x)
#         log_std = torch.clamp(log_std, min=self.log_sig_min, max=self.log_sig_max)
#
#         return mean, log_std
#
#     def sample(self, state):
#         mean, log_std = self.forward(state)
#         std = log_std.exp()
#         normal = Normal(mean, std)
#
#         x_t = normal.rsample()  # for re-parameterization trick (mean + std * N(0,1))
#         y_t = torch.tanh(x_t)
#         action = y_t
#
#         epsilon = 1e-6
#         log_prob = normal.log_prob(x_t)
#         log_prob -= torch.log((1 - y_t.pow(2)) + epsilon)
#         log_prob = log_prob.sum(1, keepdim=True)
#         mean = torch.tanh(mean)
#
#         return action, log_prob, mean


class TanhTransform(pyd.transforms.Transform):
    r"""
    Transform via the mapping :math:`y = \tanh(x)`.
    It is equivalent to
    ```
    ComposeTransform([AffineTransform(0., 2.), SigmoidTransform(), AffineTransform(-1., 2.)])
    ```
    However this might not be numerically stable, thus it is recommended to use `TanhTransform`
    instead.
    Note that one should use `cache_size=1` when it comes to `NaN/Inf` values.
    """
    domain = pyd.constraints.real
    codomain = pyd.constraints.interval(-1.0, 1.0)
    bijective = True
    sign = +1

    def __init__(self, cache_size=1):
        super().__init__(cache_size=cache_size)

    @staticmethod
    def atanh(x):
        return 0.5 * (x.log1p() - (-x).log1p())

    def __eq__(self, other):
        return isinstance(other, TanhTransform)

    def _call(self, x):
        return x.tanh()

    def _inverse(self, y):
        # We do not clamp to the boundary here as it may degrade the performance of certain algorithms.
        # one should use `cache_size=1` instead
        return self.atanh(y)

    def log_abs_det_jacobian(self, x, y):
        # This function is often used to compute the log
        # We use a formula that is more numerically stable, see details in the following link
        # https://github.com/tensorflow/probability/commit/ef6bb176e0ebd1cf6e25c6b5cecdd2428c22963f#diff-e120f70e92e6741bca649f04fcd907b7
        return 2. * (math.log(2.) - x - F.softplus(-2. * x))


class SquashedNormal(pyd.transformed_distribution.TransformedDistribution):
    def __init__(self, loc, scale):
        self.loc = loc
        self.scale = scale
        self.base_dist = pyd.Normal(loc, scale)
        # a = tanh(u)
        transforms = [TanhTransform()]
        super().__init__(self.base_dist, transforms, validate_args=False)

    @property
    def mean(self):
        mu = self.loc
        for tr in self.transforms:
            mu = tr(mu)
        return mu


class Actor(nn.Module):
    # DiagGaussianActor
    """torch.distributions implementation of a diagonal Gaussian policy."""
    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.hidden_size = [256, 256]
        self.log_std_bounds = [-20, 2]
        # Two hidden layers, 256 units each
        self.linear1 = nn.Linear(state_dim, self.hidden_size[0])
        self.linear2 = nn.Linear(self.hidden_size[0], self.hidden_size[1])
        self.mean_linear = nn.Linear(self.hidden_size[1], action_dim)
        self.log_std_linear = nn.Linear(self.hidden_size[1], action_dim)
        # self.apply(weight_init)

    def sample(self, obs):
        x = F.relu(self.linear1(obs))
        x = F.relu(self.linear2(x))
        mu = self.mean_linear(x)
        log_std = self.log_std_linear(x)

        # Bound the action to a finite interval:
        # apply an invertible squashing function (tanh) and
        # employ the change-of-variables formula to compute the likelihoods of the bounded actions.

        # constrain log_std inside [log_std_min, log_std_max]
        log_std = torch.tanh(log_std)
        log_std_min, log_std_max = self.log_std_bounds
        log_std = log_std_min + 0.5 * (log_std_max - log_std_min) * (log_std + 1)

        std = log_std.exp()

        dist = SquashedNormal(mu, std)
        sample = dist.rsample()
        log_pi = dist.log_prob(sample).sum(-1, keepdim=True)

        return sample, log_pi, dist.mean
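To make the shape conventions of the rewritten Actor concrete, here is a small, hedged usage sketch; the import path mirrors the file location above but is an assumption, and the dimensions are arbitrary.

    # Hedged usage sketch for the new Actor (dimensions and import path are assumptions).
    import torch
    from cares_reinforcement_learning.networks.SAC.Actor import Actor

    observation_size, num_actions = 17, 6
    actor = Actor(observation_size, num_actions)

    obs = torch.randn(32, observation_size)          # a batch of 32 observations
    action, log_pi, mean_action = actor.sample(obs)

    print(action.shape)       # torch.Size([32, 6]); components squashed into (-1, 1) by tanh
    print(log_pi.shape)       # torch.Size([32, 1]); log-likelihood with the tanh change-of-variables correction
    print(mean_action.shape)  # torch.Size([32, 6]); deterministic tanh(mu), useful for evaluation

Because the action comes from dist.rsample(), gradients flow through it for the actor update (the reparameterisation trick), and log_pi already accounts for the squashing via TanhTransform.log_abs_det_jacobian.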
2 changes: 1 addition & 1 deletion cares_reinforcement_learning/networks/SAC/Critic.py
@@ -7,7 +7,7 @@ class Critic(nn.Module):
    def __init__(self, observation_size, num_actions):
        super(Critic, self).__init__()

-        self.hidden_size = [1024, 1024]
+        self.hidden_size = [256, 256]

        # Q1 architecture
        self.h_linear_1 = nn.Linear(observation_size + num_actions, self.hidden_size[0])
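Only the first Q1 layer and the new hidden size are visible in this hunk. For orientation, a plausible shape of the full twin-Q critic is sketched below; the later layers and the Q2 head are assumptions, not the repository's verified code.

    # Hedged sketch of a twin-Q critic consistent with the hunk above (Q2 head is assumed).
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class TwinQCritic(nn.Module):
        def __init__(self, observation_size, num_actions):
            super().__init__()
            self.hidden_size = [256, 256]

            # Q1 architecture (first layer as shown in the diff)
            self.h_linear_1 = nn.Linear(observation_size + num_actions, self.hidden_size[0])
            self.h_linear_2 = nn.Linear(self.hidden_size[0], self.hidden_size[1])
            self.h_linear_3 = nn.Linear(self.hidden_size[1], 1)

            # Q2 architecture (assumed to mirror Q1)
            self.h_linear_4 = nn.Linear(observation_size + num_actions, self.hidden_size[0])
            self.h_linear_5 = nn.Linear(self.hidden_size[0], self.hidden_size[1])
            self.h_linear_6 = nn.Linear(self.hidden_size[1], 1)

        def forward(self, state, action):
            xu = torch.cat([state, action], dim=1)
            q1 = self.h_linear_3(F.relu(self.h_linear_2(F.relu(self.h_linear_1(xu)))))
            q2 = self.h_linear_6(F.relu(self.h_linear_5(F.relu(self.h_linear_4(xu)))))
            return q1, q2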
4 changes: 2 additions & 2 deletions cares_rl_configs/SAC_algorithm_config.json
@@ -1,8 +1,8 @@
{
  "algorithm": "SAC",

-  "actor_lr": 1e-4,
-  "critic_lr": 1e-3,
+  "actor_lr": 3e-4,
+  "critic_lr": 3e-4,

  "gamma": 0.99,
  "tau": 0.005
17 changes: 17 additions & 0 deletions cares_rl_configs/sac_training_config.json
@@ -0,0 +1,17 @@
{
  "_comment": "Different from other algorithms, SAC does automatic exploration",

  "seeds": [10, 25, 35, 45, 55],

  "G": 1,
  "batch_size": 256,

  "max_steps_exploration": 0,
  "max_steps_training": 1000000,

  "number_steps_per_evaluation": 10000,
  "number_eval_episodes": 10,

  "plot_frequency": 100,
  "checkpoint_frequency": 100
}
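A short sketch of how these two JSON files might be consumed; the repository's actual config loader is not shown in this diff, so treat the paths and field access below as assumptions.

    # Hedged sketch: reading the SAC config files shown above (paths and usage are assumptions).
    import json

    with open("cares_rl_configs/SAC_algorithm_config.json") as f:
        algorithm_config = json.load(f)
    with open("cares_rl_configs/sac_training_config.json") as f:
        training_config = json.load(f)

    print(algorithm_config["actor_lr"], algorithm_config["critic_lr"])   # 0.0003 0.0003
    print(training_config["batch_size"], training_config["G"])           # 256 1

    for seed in training_config["seeds"]:
        # max_steps_exploration is 0 because SAC explores through its stochastic
        # policy and entropy bonus rather than a separate random-action phase.
        print(f"training run: seed={seed}, max_steps={training_config['max_steps_training']}")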
