Commit: Merge branch 'main' into docs/add-examples-directory
Showing 4 changed files with 219 additions and 0 deletions.
cares_reinforcement_learning/algorithm/policy/LAPTD3.py (181 additions, 0 deletions)
@@ -0,0 +1,181 @@
""" | ||
Original Paper: https://arxiv.org/abs/2007.06049 | ||
""" | ||
|
||
import copy | ||
import logging | ||
import os | ||
import numpy as np | ||
import torch | ||
import torch.nn.functional as F | ||
|
||
|
||
class LAPTD3:
    """TD3 with Loss-Adjusted Prioritised Experience Replay (LAP)."""

    def __init__(
        self,
        actor_network,
        critic_network,
        gamma,
        tau,
        alpha,
        min_priority,
        action_num,
        actor_lr,
        critic_lr,
        device,
    ):
        self.type = "policy"
        self.actor_net = actor_network.to(device)
        self.critic_net = critic_network.to(device)

        self.target_actor_net = copy.deepcopy(self.actor_net)  # .to(device)
        self.target_critic_net = copy.deepcopy(self.critic_net)  # .to(device)

        self.gamma = gamma
        self.tau = tau
        self.alpha = alpha
        self.min_priority = min_priority

        self.learn_counter = 0
        self.policy_update_freq = 2

        self.action_num = action_num
        self.device = device

        self.actor_net_optimiser = torch.optim.Adam(
            self.actor_net.parameters(), lr=actor_lr
        )
        self.critic_net_optimiser = torch.optim.Adam(
            self.critic_net.parameters(), lr=critic_lr
        )

    def select_action_from_policy(self, state, evaluation=False, noise_scale=0.1):
        self.actor_net.eval()
        with torch.no_grad():
            state_tensor = torch.FloatTensor(state).to(self.device)
            state_tensor = state_tensor.unsqueeze(0)
            action = self.actor_net(state_tensor)
            action = action.cpu().data.numpy().flatten()
            if not evaluation:
                # As in TD3, add exploration noise to the action
                noise = np.random.normal(0, scale=noise_scale, size=self.action_num)
                action = action + noise
                action = np.clip(action, -1, 1)
        self.actor_net.train()
        return action

    def huber(self, x):
        # LAP loss: quadratic below the min_priority threshold, linear above it
        return torch.where(
            x < self.min_priority, 0.5 * x.pow(2), self.min_priority * x
        ).mean()

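    # train_policy below follows the standard TD3 recipe (clipped target policy
    # noise, twin critics, delayed actor and target updates) with the LAP changes
    # from the paper cited above: the critics are trained with the Huber-style loss
    # in huber() rather than an importance-weighted MSE, and fresh priorities are
    # written back to the replay buffer after every update.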
    def train_policy(self, memory, batch_size):
        self.learn_counter += 1

        experiences = memory.sample(batch_size)
        states, actions, rewards, next_states, dones, indices, weights = experiences

        batch_size = len(states)

        # Convert into tensors
        states = torch.FloatTensor(np.asarray(states)).to(self.device)
        actions = torch.FloatTensor(np.asarray(actions)).to(self.device)
        rewards = torch.FloatTensor(np.asarray(rewards)).to(self.device)
        next_states = torch.FloatTensor(np.asarray(next_states)).to(self.device)
        dones = torch.LongTensor(np.asarray(dones)).to(self.device)
        # Importance-sampling weights are not used: LAP's loss adjustment replaces them
        weights = torch.FloatTensor(np.asarray(weights)).to(self.device)

        # Reshape to batch_size x 1
        rewards = rewards.unsqueeze(0).reshape(batch_size, 1)
        dones = dones.unsqueeze(0).reshape(batch_size, 1)

        with torch.no_grad():
            next_actions = self.target_actor_net(next_states)
            target_noise = 0.2 * torch.randn_like(next_actions)
            target_noise = torch.clamp(target_noise, -0.5, 0.5)
            next_actions = next_actions + target_noise
            next_actions = torch.clamp(next_actions, min=-1, max=1)

            target_q_values_one, target_q_values_two = self.target_critic_net(
                next_states, next_actions
            )
            target_q_values = torch.minimum(target_q_values_one, target_q_values_two)

            q_target = rewards + self.gamma * (1 - dones) * target_q_values

        q_values_one, q_values_two = self.critic_net(states, actions)

        # Absolute TD errors of the current critics against the bootstrapped target
        td_loss_one = (q_values_one - q_target).abs()
        td_loss_two = (q_values_two - q_target).abs()

        # LAP critic loss: the Huber-style loss applied to each critic's TD errors
        critic_loss_one = self.huber(td_loss_one)
        critic_loss_two = self.huber(td_loss_two)
        critic_loss_total = critic_loss_one + critic_loss_two

        # Update the Critic
        self.critic_net_optimiser.zero_grad()
        critic_loss_total.backward()
        self.critic_net_optimiser.step()

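        # Per the LAP paper, a transition's new priority is the larger of its two
        # absolute TD errors, clamped from below by min_priority and raised to the
        # power alpha, so low-error samples keep a uniform floor probability.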
        priorities = (
            torch.max(td_loss_one, td_loss_two)
            .clamp(min=self.min_priority)
            .pow(self.alpha)
            .cpu()
            .data.numpy()
            .flatten()
        )

        if self.learn_counter % self.policy_update_freq == 0:
            # actor_q_one, actor_q_two = self.critic_net(states, self.actor_net(states))
            # actor_q_values = torch.minimum(actor_q_one, actor_q_two)

            # Update Actor
            actor_q_values, _ = self.critic_net(states, self.actor_net(states))
            actor_loss = -actor_q_values.mean()

            self.actor_net_optimiser.zero_grad()
            actor_loss.backward()
            self.actor_net_optimiser.step()

            # Update target network params
            for target_param, param in zip(
                self.target_critic_net.Q1.parameters(), self.critic_net.Q1.parameters()
            ):
                target_param.data.copy_(
                    param.data * self.tau + target_param.data * (1.0 - self.tau)
                )

            for target_param, param in zip(
                self.target_critic_net.Q2.parameters(), self.critic_net.Q2.parameters()
            ):
                target_param.data.copy_(
                    param.data * self.tau + target_param.data * (1.0 - self.tau)
                )

            for target_param, param in zip(
                self.target_actor_net.parameters(), self.actor_net.parameters()
            ):
                target_param.data.copy_(
                    param.data * self.tau + target_param.data * (1.0 - self.tau)
                )

        memory.update_priorities(indices, priorities)

    def save_models(self, filename, filepath="models"):
        path = f"{filepath}/models" if filepath != "models" else filepath
        dir_exists = os.path.exists(path)

        if not dir_exists:
            os.makedirs(path)

        torch.save(self.actor_net.state_dict(), f"{path}/{filename}_actor.pht")
        torch.save(self.critic_net.state_dict(), f"{path}/{filename}_critic.pht")
        logging.info("models have been saved...")

    def load_models(self, filepath, filename):
        path = f"{filepath}/models" if filepath != "models" else filepath

        self.actor_net.load_state_dict(torch.load(f"{path}/{filename}_actor.pht"))
        self.critic_net.load_state_dict(torch.load(f"{path}/{filename}_critic.pht"))
        logging.info("models have been loaded...")
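
For orientation, here is a minimal sketch of how the new class might be exercised. The Actor and Critic classes below are illustrative stand-ins only (the critic must expose Q1 and Q2 heads and return both Q-values), the prioritised replay buffer expected by train_policy (with sample and update_priorities methods) is assumed to live elsewhere in the repository, and all hyperparameter values are placeholders rather than recommended settings:

import torch
import torch.nn as nn


# Stand-in networks for illustration only; the repository provides its own
# Actor/Critic implementations.
class Actor(nn.Module):
    def __init__(self, observation_size, action_num):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(observation_size, 64),
            nn.ReLU(),
            nn.Linear(64, action_num),
            nn.Tanh(),
        )

    def forward(self, state):
        return self.net(state)


class Critic(nn.Module):
    def __init__(self, observation_size, action_num):
        super().__init__()
        self.Q1 = nn.Sequential(
            nn.Linear(observation_size + action_num, 64), nn.ReLU(), nn.Linear(64, 1)
        )
        self.Q2 = nn.Sequential(
            nn.Linear(observation_size + action_num, 64), nn.ReLU(), nn.Linear(64, 1)
        )

    def forward(self, state, action):
        x = torch.cat([state, action], dim=1)
        return self.Q1(x), self.Q2(x)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

agent = LAPTD3(
    actor_network=Actor(observation_size=8, action_num=2),
    critic_network=Critic(observation_size=8, action_num=2),
    gamma=0.99,
    tau=0.005,
    alpha=0.4,          # priority exponent
    min_priority=1.0,   # Huber threshold and priority floor
    action_num=2,
    actor_lr=1e-4,
    critic_lr=1e-3,
    device=device,
)

action = agent.select_action_from_policy([0.0] * 8)  # noisy action in [-1, 1]
# agent.train_policy(memory, batch_size=32)  # memory: a LAP-compatible prioritised buffer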
@@ -6,3 +6,4 @@
from .CTD4 import CTD4
from .RDTD3 import RDTD3
from .PERTD3 import PERTD3
from .LAPTD3 import LAPTD3
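
With this line added to what appears to be the policy package's __init__.py, the new agent should be importable alongside the existing algorithms, for example:

from cares_reinforcement_learning.algorithm.policy import LAPTD3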