Skip to content

Commit

Permalink
Added LAPTD3 (#137)
Browse files Browse the repository at this point in the history
* Added LAPTD3

* Actually add LAPTD3

* pushed sample into algorithms

* updated to memory in algorithm
  • Loading branch information
beardyFace authored Apr 10, 2024
1 parent 6a3c465 commit e755bb0
Showing 4 changed files with 219 additions and 0 deletions.
181 changes: 181 additions & 0 deletions cares_reinforcement_learning/algorithm/policy/LAPTD3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
"""
Original Paper: https://arxiv.org/abs/2007.06049
"""

import copy
import logging
import os
import numpy as np
import torch
import torch.nn.functional as F


class LAPTD3:
def __init__(
self,
actor_network,
critic_network,
gamma,
tau,
alpha,
min_priority,
action_num,
actor_lr,
critic_lr,
device,
):
self.type = "policy"
self.actor_net = actor_network.to(device)
self.critic_net = critic_network.to(device)

self.target_actor_net = copy.deepcopy(self.actor_net) # .to(device)
self.target_critic_net = copy.deepcopy(self.critic_net) # .to(device)

self.gamma = gamma
self.tau = tau
self.alpha = alpha
self.min_priority = min_priority

self.learn_counter = 0
self.policy_update_freq = 2

self.action_num = action_num
self.device = device

self.actor_net_optimiser = torch.optim.Adam(
self.actor_net.parameters(), lr=actor_lr
)
self.critic_net_optimiser = torch.optim.Adam(
self.critic_net.parameters(), lr=critic_lr
)

def select_action_from_policy(self, state, evaluation=False, noise_scale=0.1):
self.actor_net.eval()
with torch.no_grad():
state_tensor = torch.FloatTensor(state).to(self.device)
state_tensor = state_tensor.unsqueeze(0)
action = self.actor_net(state_tensor)
action = action.cpu().data.numpy().flatten()
if not evaluation:
# this is part the TD3 too, add noise to the action
noise = np.random.normal(0, scale=noise_scale, size=self.action_num)
action = action + noise
action = np.clip(action, -1, 1)
self.actor_net.train()
return action

def huber(self, x):
return torch.where(
x < self.min_priority, 0.5 * x.pow(2), self.min_priority * x
).mean()

def train_policy(self, memory, batch_size):
self.learn_counter += 1

experiences = memory.sample(batch_size)
states, actions, rewards, next_states, dones, indices, weights = experiences

batch_size = len(states)

# Convert into tensor
states = torch.FloatTensor(np.asarray(states)).to(self.device)
actions = torch.FloatTensor(np.asarray(actions)).to(self.device)
rewards = torch.FloatTensor(np.asarray(rewards)).to(self.device)
next_states = torch.FloatTensor(np.asarray(next_states)).to(self.device)
dones = torch.LongTensor(np.asarray(dones)).to(self.device)
weights = torch.LongTensor(np.asarray(weights)).to(self.device)

# Reshape to batch_size
rewards = rewards.unsqueeze(0).reshape(batch_size, 1)
dones = dones.unsqueeze(0).reshape(batch_size, 1)

with torch.no_grad():
next_actions = self.target_actor_net(next_states)
target_noise = 0.2 * torch.randn_like(next_actions)
target_noise = torch.clamp(target_noise, -0.5, 0.5)
next_actions = next_actions + target_noise
next_actions = torch.clamp(next_actions, min=-1, max=1)

target_q_values_one, target_q_values_two = self.target_critic_net(
next_states, next_actions
)
target_q_values = torch.minimum(target_q_values_one, target_q_values_two)

q_target = rewards + self.gamma * (1 - dones) * target_q_values

q_values_one, q_values_two = self.critic_net(states, actions)

td_loss_one = (target_q_values_one - q_target).abs()
td_loss_two = (target_q_values_two - q_target).abs()

critic_loss_one = F.mse_loss(q_values_one, q_target)
critic_loss_two = F.mse_loss(q_values_two, q_target)

critic_loss_total = self.huber(critic_loss_one) + self.huber(critic_loss_two)

# Update the Critic
self.critic_net_optimiser.zero_grad()
torch.mean(critic_loss_total).backward()
self.critic_net_optimiser.step()

priorities = (
torch.max(td_loss_one, td_loss_two)
.pow(self.alpha)
.cpu()
.data.numpy()
.flatten()
)

if self.learn_counter % self.policy_update_freq == 0:
# actor_q_one, actor_q_two = self.critic_net(states, self.actor_net(states))
# actor_q_values = torch.minimum(actor_q_one, actor_q_two)

# Update Actor
actor_q_values, _ = self.critic_net(states, self.actor_net(states))
actor_loss = -actor_q_values.mean()

self.actor_net_optimiser.zero_grad()
actor_loss.backward()
self.actor_net_optimiser.step()

# Update target network params
for target_param, param in zip(
self.target_critic_net.Q1.parameters(), self.critic_net.Q1.parameters()
):
target_param.data.copy_(
param.data * self.tau + target_param.data * (1.0 - self.tau)
)

for target_param, param in zip(
self.target_critic_net.Q2.parameters(), self.critic_net.Q2.parameters()
):
target_param.data.copy_(
param.data * self.tau + target_param.data * (1.0 - self.tau)
)

for target_param, param in zip(
self.target_actor_net.parameters(), self.actor_net.parameters()
):
target_param.data.copy_(
param.data * self.tau + target_param.data * (1.0 - self.tau)
)

memory.update_priorities(indices, priorities)

def save_models(self, filename, filepath="models"):
path = f"{filepath}/models" if filepath != "models" else filepath
dir_exists = os.path.exists(path)

if not dir_exists:
os.makedirs(path)

torch.save(self.actor_net.state_dict(), f"{path}/{filename}_actor.pht")
torch.save(self.critic_net.state_dict(), f"{path}/{filename}_critic.pht")
logging.info("models has been saved...")

def load_models(self, filepath, filename):
path = f"{filepath}/models" if filepath != "models" else filepath

self.actor_net.load_state_dict(torch.load(f"{path}/{filename}_actor.pht"))
self.critic_net.load_state_dict(torch.load(f"{path}/{filename}_critic.pht"))
logging.info("models has been loaded...")
1 change: 1 addition & 0 deletions cares_reinforcement_learning/algorithm/policy/__init__.py
Original file line number Diff line number Diff line change
@@ -6,3 +6,4 @@
from .CTD4 import CTD4
from .RDTD3 import RDTD3
from .PERTD3 import PERTD3
from .LAPTD3 import LAPTD3
14 changes: 14 additions & 0 deletions cares_reinforcement_learning/util/configurations.py
Original file line number Diff line number Diff line change
@@ -184,3 +184,17 @@ class PERTD3Config(AlgorithmConfig):

noise_scale: Optional[float] = 0.1
noise_decay: Optional[float] = 1


class LAPTD3Config(AlgorithmConfig):
algorithm: str = Field("LAPTD3", Literal=True)

actor_lr: Optional[float] = 1e-4
critic_lr: Optional[float] = 1e-3
gamma: Optional[float] = 0.99
tau: Optional[float] = 0.005
alpha: Optional[float] = 0.6
min_priority: Optional[float] = 1.0

noise_scale: Optional[float] = 0.1
noise_decay: Optional[float] = 1
23 changes: 23 additions & 0 deletions cares_reinforcement_learning/util/network_factory.py
Original file line number Diff line number Diff line change
@@ -284,6 +284,29 @@ def create_PERTD3(observation_size, action_num, config: AlgorithmConfig):
return agent


def create_LAPTD3(observation_size, action_num, config: AlgorithmConfig):
from cares_reinforcement_learning.algorithm.policy import LAPTD3
from cares_reinforcement_learning.networks.TD3 import Actor, Critic

actor = Actor(observation_size, action_num)
critic = Critic(observation_size, action_num)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
agent = LAPTD3(
actor_network=actor,
critic_network=critic,
gamma=config.gamma,
tau=config.tau,
alpha=config.alpha,
min_priority=config.min_priority,
action_num=action_num,
actor_lr=config.actor_lr,
critic_lr=config.critic_lr,
device=device,
)
return agent


class NetworkFactory:
def create_network(self, observation_size, action_num, config: AlgorithmConfig):
algorithm = config.algorithm

0 comments on commit e755bb0

Please sign in to comment.