
Commit

minor errors with weight mse fixed (#139)
beardyFace authored Apr 5, 2024
1 parent 8506f46 commit b6e8b20
Showing 2 changed files with 10 additions and 6 deletions.
10 changes: 6 additions & 4 deletions cares_reinforcement_learning/algorithm/policy/PERTD3.py
@@ -102,14 +102,16 @@ def train_policy(self, experiences):
         td_loss_one = (target_q_values_one - q_target).abs()
         td_loss_two = (target_q_values_two - q_target).abs()

-        critic_loss_one = F.mse_loss(q_values_one, q_target)
-        critic_loss_two = F.mse_loss(q_values_two, q_target)
+        critic_loss_one = F.mse_loss(q_values_one, q_target, reduction="none")
+        critic_loss_two = F.mse_loss(q_values_two, q_target, reduction="none")

-        critic_loss_total = critic_loss_one * weights + critic_loss_two * weights
+        critic_loss_total = (critic_loss_one * weights).mean() + (
+            critic_loss_two * weights
+        ).mean()

         # Update the Critic
         self.critic_net_optimiser.zero_grad()
-        torch.mean(critic_loss_total).backward()
+        critic_loss_total.backward()
         self.critic_net_optimiser.step()

         priorities = (
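The change passes reduction="none" to F.mse_loss so it returns one squared error per sample, which can then be scaled by the prioritised-replay importance-sampling weights before being reduced to a scalar for backward(). A minimal standalone sketch of that pattern (illustrative tensor values only, not taken from the repository; in PERTD3 the predictions come from the critic, the targets from the Bellman backup, and weights from the PER buffer):

import torch
import torch.nn.functional as F

# Hypothetical stand-ins for the critic output, target, and PER weights.
q_values = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
q_target = torch.tensor([1.5, 1.0, 3.5])
weights = torch.tensor([0.2, 1.0, 0.5])

# reduction="none" keeps a squared error per sample ...
per_sample_loss = F.mse_loss(q_values, q_target, reduction="none")

# ... so each sample is scaled by its importance-sampling weight
# before the batch is averaged down to a scalar loss.
critic_loss = (per_sample_loss * weights).mean()
critic_loss.backward()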
6 changes: 4 additions & 2 deletions cares_reinforcement_learning/algorithm/policy/RDTD3.py
@@ -156,11 +156,13 @@ def train_policy(self, experience):
             + self.scale_s * diff_next_states_two
         )

-        critic_loss_total = critic_one_loss * weights + critic_two_loss * weights
+        critic_loss_total = (critic_one_loss * weights).mean() + (
+            critic_two_loss * weights
+        ).mean()

         # train critic
         self.critic_net_optimiser.zero_grad()
-        torch.mean(critic_loss_total).backward()
+        critic_loss_total.backward()
         self.critic_net_optimiser.step()
         ############################

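For reference, a short sketch of why the previous version mis-weighted the loss: with the default reduction="mean", F.mse_loss already returns a scalar, so multiplying by the weight tensor and then averaging only rescales the unweighted loss by weights.mean() rather than weighting each sample's error. Illustrative values only (names and numbers are made up, not from the repository):

import torch
import torch.nn.functional as F

q_values = torch.tensor([1.0, 2.0, 3.0])
q_target = torch.tensor([1.5, 1.0, 3.5])
weights = torch.tensor([0.2, 1.0, 0.5])

# Previous behaviour: the loss is a scalar before weighting, so the weights
# only rescale the overall mean instead of the individual errors.
old_loss = torch.mean(F.mse_loss(q_values, q_target) * weights)

# Fixed behaviour: weight each per-sample squared error, then average.
new_loss = (F.mse_loss(q_values, q_target, reduction="none") * weights).mean()

print(old_loss.item(), new_loss.item())  # the two values differ unless weights are uniform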
