diff --git a/cares_reinforcement_learning/algorithm/policy/PERTD3.py b/cares_reinforcement_learning/algorithm/policy/PERTD3.py
index 727f4601..e5db0baf 100644
--- a/cares_reinforcement_learning/algorithm/policy/PERTD3.py
+++ b/cares_reinforcement_learning/algorithm/policy/PERTD3.py
@@ -102,14 +102,16 @@ def train_policy(self, experiences):
         td_loss_one = (target_q_values_one - q_target).abs()
         td_loss_two = (target_q_values_two - q_target).abs()

-        critic_loss_one = F.mse_loss(q_values_one, q_target)
-        critic_loss_two = F.mse_loss(q_values_two, q_target)
+        critic_loss_one = F.mse_loss(q_values_one, q_target, reduction="none")
+        critic_loss_two = F.mse_loss(q_values_two, q_target, reduction="none")

-        critic_loss_total = critic_loss_one * weights + critic_loss_two * weights
+        critic_loss_total = (critic_loss_one * weights).mean() + (
+            critic_loss_two * weights
+        ).mean()

         # Update the Critic
         self.critic_net_optimiser.zero_grad()
-        torch.mean(critic_loss_total).backward()
+        critic_loss_total.backward()
         self.critic_net_optimiser.step()

         priorities = (
diff --git a/cares_reinforcement_learning/algorithm/policy/RDTD3.py b/cares_reinforcement_learning/algorithm/policy/RDTD3.py
index 49671fc9..cb9c9534 100644
--- a/cares_reinforcement_learning/algorithm/policy/RDTD3.py
+++ b/cares_reinforcement_learning/algorithm/policy/RDTD3.py
@@ -156,11 +156,13 @@ def train_policy(self, experience):
             + self.scale_s * diff_next_states_two
         )

-        critic_loss_total = critic_one_loss * weights + critic_two_loss * weights
+        critic_loss_total = (critic_one_loss * weights).mean() + (
+            critic_two_loss * weights
+        ).mean()

         # train critic
         self.critic_net_optimiser.zero_grad()
-        torch.mean(critic_loss_total).backward()
+        critic_loss_total.backward()
         self.critic_net_optimiser.step()

         ############################
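
Note on the change (not part of the diff): with prioritized experience replay, the importance-sampling weights have to be applied to the per-sample critic errors before the loss is reduced. The old code let `F.mse_loss` reduce to a scalar first, so multiplying by `weights` afterwards only rescaled that scalar by each weight. A minimal sketch of the difference, using illustrative tensor values rather than the repository's networks:

```python
import torch
import torch.nn.functional as F

# Illustrative values; in the algorithms these come from the critic and the PER buffer.
q_values = torch.tensor([1.0, 2.0, 3.0])
q_target = torch.tensor([1.5, 2.0, 0.0])
weights = torch.tensor([0.2, 1.0, 0.5])  # PER importance-sampling weights

# Old pattern: mse_loss reduces to a scalar first, so multiplying by `weights`
# just rescales that scalar; the per-sample errors are already averaged away.
old_loss = torch.mean(F.mse_loss(q_values, q_target) * weights)

# New pattern: keep the element-wise errors, weight each sample, then reduce.
new_loss = (F.mse_loss(q_values, q_target, reduction="none") * weights).mean()

print(old_loss.item(), new_loss.item())  # the two differ whenever the weights vary
```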