diff --git a/cares_reinforcement_learning/algorithm/policy/LAPTD3.py b/cares_reinforcement_learning/algorithm/policy/LAPTD3.py
new file mode 100644
index 00000000..589fcc5d
--- /dev/null
+++ b/cares_reinforcement_learning/algorithm/policy/LAPTD3.py
@@ -0,0 +1,185 @@
+"""
+Original Paper: https://arxiv.org/abs/2007.06049
+"""
+
+import copy
+import logging
+import os
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+
+class LAPTD3:
+    def __init__(
+        self,
+        actor_network,
+        critic_network,
+        gamma,
+        tau,
+        alpha,
+        min_priority,
+        action_num,
+        actor_lr,
+        critic_lr,
+        device,
+    ):
+        self.type = "policy"
+        self.actor_net = actor_network.to(device)
+        self.critic_net = critic_network.to(device)
+
+        self.target_actor_net = copy.deepcopy(self.actor_net)  # .to(device)
+        self.target_critic_net = copy.deepcopy(self.critic_net)  # .to(device)
+
+        self.gamma = gamma
+        self.tau = tau
+        self.alpha = alpha
+        self.min_priority = min_priority
+
+        self.learn_counter = 0
+        self.policy_update_freq = 2
+
+        self.action_num = action_num
+        self.device = device
+
+        self.actor_net_optimiser = torch.optim.Adam(
+            self.actor_net.parameters(), lr=actor_lr
+        )
+        self.critic_net_optimiser = torch.optim.Adam(
+            self.critic_net.parameters(), lr=critic_lr
+        )
+
+    def select_action_from_policy(self, state, evaluation=False, noise_scale=0.1):
+        self.actor_net.eval()
+        with torch.no_grad():
+            state_tensor = torch.FloatTensor(state).to(self.device)
+            state_tensor = state_tensor.unsqueeze(0)
+            action = self.actor_net(state_tensor)
+            action = action.cpu().data.numpy().flatten()
+            if not evaluation:
+                # this is part of TD3 too, add exploration noise to the action
+                noise = np.random.normal(0, scale=noise_scale, size=self.action_num)
+                action = action + noise
+                action = np.clip(action, -1, 1)
+        self.actor_net.train()
+        return action
+
+    def huber(self, x):
+        return torch.where(
+            x < self.min_priority, 0.5 * x.pow(2), self.min_priority * x
+        ).mean()
+
+    def train_policy(self, memory, batch_size):
+        self.learn_counter += 1
+
+        experiences = memory.sample(batch_size)
+        states, actions, rewards, next_states, dones, indices, weights = experiences
+
+        batch_size = len(states)
+
+        # Convert into tensors
+        states = torch.FloatTensor(np.asarray(states)).to(self.device)
+        actions = torch.FloatTensor(np.asarray(actions)).to(self.device)
+        rewards = torch.FloatTensor(np.asarray(rewards)).to(self.device)
+        next_states = torch.FloatTensor(np.asarray(next_states)).to(self.device)
+        dones = torch.LongTensor(np.asarray(dones)).to(self.device)
+        weights = torch.FloatTensor(np.asarray(weights)).to(self.device)
+
+        # Reshape to batch_size x 1
+        rewards = rewards.unsqueeze(0).reshape(batch_size, 1)
+        dones = dones.unsqueeze(0).reshape(batch_size, 1)
+
+        with torch.no_grad():
+            next_actions = self.target_actor_net(next_states)
+            target_noise = 0.2 * torch.randn_like(next_actions)
+            target_noise = torch.clamp(target_noise, -0.5, 0.5)
+            next_actions = next_actions + target_noise
+            next_actions = torch.clamp(next_actions, min=-1, max=1)
+
+            target_q_values_one, target_q_values_two = self.target_critic_net(
+                next_states, next_actions
+            )
+            target_q_values = torch.minimum(target_q_values_one, target_q_values_two)
+
+            q_target = rewards + self.gamma * (1 - dones) * target_q_values
+
+        q_values_one, q_values_two = self.critic_net(states, actions)
+
+        # Absolute TD errors of the current critic estimates - used for the loss and the priorities
+        td_loss_one = (q_values_one - q_target).abs()
+        td_loss_two = (q_values_two - q_target).abs()
+
+        # LAP applies the Huber loss to the absolute TD errors in place of the usual MSE
+        critic_loss_one = self.huber(td_loss_one)
+        critic_loss_two = self.huber(td_loss_two)
+
+        critic_loss_total = critic_loss_one + critic_loss_two
+
+        # Update the Critic
+        self.critic_net_optimiser.zero_grad()
+        critic_loss_total.backward()
+        self.critic_net_optimiser.step()
+
+        # LAP priorities: max(|TD error|, min_priority) ** alpha
+        priorities = (
+            torch.max(td_loss_one, td_loss_two)
+            .clamp(min=self.min_priority)
+            .pow(self.alpha)
+            .cpu()
+            .data.numpy()
+            .flatten()
+        )
+
+        if self.learn_counter % self.policy_update_freq == 0:
+            # actor_q_one, actor_q_two = self.critic_net(states, self.actor_net(states))
+            # actor_q_values = torch.minimum(actor_q_one, actor_q_two)
+
+            # Update Actor
+            actor_q_values, _ = self.critic_net(states, self.actor_net(states))
+            actor_loss = -actor_q_values.mean()
+
+            self.actor_net_optimiser.zero_grad()
+            actor_loss.backward()
+            self.actor_net_optimiser.step()
+
+            # Update target network params
+            for target_param, param in zip(
+                self.target_critic_net.Q1.parameters(), self.critic_net.Q1.parameters()
+            ):
+                target_param.data.copy_(
+                    param.data * self.tau + target_param.data * (1.0 - self.tau)
+                )
+
+            for target_param, param in zip(
+                self.target_critic_net.Q2.parameters(), self.critic_net.Q2.parameters()
+            ):
+                target_param.data.copy_(
+                    param.data * self.tau + target_param.data * (1.0 - self.tau)
+                )
+
+            for target_param, param in zip(
+                self.target_actor_net.parameters(), self.actor_net.parameters()
+            ):
+                target_param.data.copy_(
+                    param.data * self.tau + target_param.data * (1.0 - self.tau)
+                )
+
+        memory.update_priorities(indices, priorities)
+
+    def save_models(self, filename, filepath="models"):
+        path = f"{filepath}/models" if filepath != "models" else filepath
+        dir_exists = os.path.exists(path)
+
+        if not dir_exists:
+            os.makedirs(path)
+
+        torch.save(self.actor_net.state_dict(), f"{path}/{filename}_actor.pht")
+        torch.save(self.critic_net.state_dict(), f"{path}/{filename}_critic.pht")
+        logging.info("models have been saved...")
+
+    def load_models(self, filepath, filename):
+        path = f"{filepath}/models" if filepath != "models" else filepath
+
+        self.actor_net.load_state_dict(torch.load(f"{path}/{filename}_actor.pht"))
+        self.critic_net.load_state_dict(torch.load(f"{path}/{filename}_critic.pht"))
+        logging.info("models have been loaded...")
diff --git a/cares_reinforcement_learning/algorithm/policy/__init__.py b/cares_reinforcement_learning/algorithm/policy/__init__.py
index a2a6deb8..af6ced8a 100644
--- a/cares_reinforcement_learning/algorithm/policy/__init__.py
+++ b/cares_reinforcement_learning/algorithm/policy/__init__.py
@@ -6,3 +6,4 @@
 from .CTD4 import CTD4
 from .RDTD3 import RDTD3
 from .PERTD3 import PERTD3
+from .LAPTD3 import LAPTD3
diff --git a/cares_reinforcement_learning/util/configurations.py b/cares_reinforcement_learning/util/configurations.py
index d2b91559..571a4bca 100644
--- a/cares_reinforcement_learning/util/configurations.py
+++ b/cares_reinforcement_learning/util/configurations.py
@@ -184,3 +184,17 @@ class PERTD3Config(AlgorithmConfig):
 
     noise_scale: Optional[float] = 0.1
     noise_decay: Optional[float] = 1
+
+
+class LAPTD3Config(AlgorithmConfig):
+    algorithm: str = Field("LAPTD3", Literal=True)
+
+    actor_lr: Optional[float] = 1e-4
+    critic_lr: Optional[float] = 1e-3
+    gamma: Optional[float] = 0.99
+    tau: Optional[float] = 0.005
+    alpha: Optional[float] = 0.6
+    min_priority: Optional[float] = 1.0
+
+    noise_scale: Optional[float] = 0.1
+    noise_decay: Optional[float] = 1
diff --git a/cares_reinforcement_learning/util/network_factory.py b/cares_reinforcement_learning/util/network_factory.py
index 7a07e2cc..5d642c62 100644
--- a/cares_reinforcement_learning/util/network_factory.py
+++ b/cares_reinforcement_learning/util/network_factory.py
@@ -284,6 +284,29 @@ def create_PERTD3(observation_size, action_num, config: AlgorithmConfig):
     return agent
 
+
+def create_LAPTD3(observation_size, action_num, config: AlgorithmConfig):
+    from cares_reinforcement_learning.algorithm.policy import LAPTD3
+    from cares_reinforcement_learning.networks.TD3 import Actor, Critic
+
+    actor = Actor(observation_size, action_num)
+    critic = Critic(observation_size, action_num)
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    agent = LAPTD3(
+        actor_network=actor,
+        critic_network=critic,
+        gamma=config.gamma,
+        tau=config.tau,
+        alpha=config.alpha,
+        min_priority=config.min_priority,
+        action_num=action_num,
+        actor_lr=config.actor_lr,
+        critic_lr=config.critic_lr,
+        device=device,
+    )
+    return agent
+
 
 class NetworkFactory:
     def create_network(self, observation_size, action_num, config: AlgorithmConfig):
         algorithm = config.algorithm
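
For reference, a minimal usage sketch of the pieces added above (not part of the diff). It assumes NetworkFactory.create_network dispatches on config.algorithm to the matching create_* helper such as create_LAPTD3 (the dispatch body lies outside this hunk), that NetworkFactory takes no constructor arguments, and that the observation/action sizes used here are illustrative placeholders:

    import numpy as np

    from cares_reinforcement_learning.util.configurations import LAPTD3Config
    from cares_reinforcement_learning.util.network_factory import NetworkFactory

    # Defaults from LAPTD3Config: gamma=0.99, tau=0.005, alpha=0.6, min_priority=1.0
    config = LAPTD3Config()

    observation_size = 17  # illustrative sizes, not taken from the diff
    action_num = 6

    agent = NetworkFactory().create_network(observation_size, action_num, config)

    # Action selection works as in the other TD3 variants; a training loop would
    # call agent.train_policy(memory, batch_size) with a prioritised replay buffer
    # that supports sample(batch_size) and update_priorities(indices, priorities).
    state = np.zeros(observation_size, dtype=np.float32)  # placeholder observation
    action = agent.select_action_from_policy(state, evaluation=True)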