
Commit

Changing to adhere to naming convention.
qiaoting159753 committed Mar 10, 2024
1 parent 6cc67a4 commit c911bb1
Showing 8 changed files with 26 additions and 26 deletions.
18 changes: 9 additions & 9 deletions cares_reinforcement_learning/algorithm/mbrl/DYNA_SAC.py
@@ -13,7 +13,7 @@
import torch.nn.functional as F


-class MBRL_DYNA_SAC:
+class DynaSAC:
"""
Use the Soft Actor Critic as the Actor Critic framework.
@@ -72,7 +72,7 @@ def __init__(
self.policy_update_freq = 1

@property
-def alpha(self):
+def _alpha(self):
"""
A variable that decides to what extent entropy should be valued.
"""
@@ -98,7 +98,7 @@ def select_action_from_policy(self, state, evaluation=False, noise_scale=0):
self.actor_net.train()
return action

-def true_train_policy(self, states, actions, rewards, next_states, dones):
+def _true_train_policy(self, states, actions, rewards, next_states, dones):
"""
Train the policy with Model-Based Value Expansion, a family of MBRL methods.
@@ -110,7 +110,7 @@ def true_train_policy(self, states, actions, rewards, next_states, dones):
next_states, next_actions
)
target_q_values = (
-torch.minimum(target_q_one, target_q_two) - self.alpha * next_log_pi
+torch.minimum(target_q_one, target_q_two) - self._alpha * next_log_pi
)
q_target = rewards + self.gamma * (1 - dones) * target_q_values
q_target = q_target.detach()
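The hunk above forms the soft Bellman target r + gamma * (1 - done) * (min(Q1', Q2') - alpha * log_pi'). The critic loss that consumes this target is outside the excerpt; the sketch below completes the step with the standard twin-critic MSE form, which is an assumption here, and uses random tensors purely for shape illustration.

import torch
import torch.nn.functional as F

# Shapes and constants are illustrative only.
batch, gamma, alpha = 4, 0.99, 0.2
rewards, dones = torch.rand(batch, 1), torch.zeros(batch, 1)
target_q_one, target_q_two = torch.rand(batch, 1), torch.rand(batch, 1)
next_log_pi = torch.randn(batch, 1)

target_q_values = torch.minimum(target_q_one, target_q_two) - alpha * next_log_pi
q_target = (rewards + gamma * (1 - dones) * target_q_values).detach()

# Assumed continuation (standard SAC): both critics regress onto the same target.
q_one = torch.rand(batch, 1, requires_grad=True)
q_two = torch.rand(batch, 1, requires_grad=True)
critic_loss = F.mse_loss(q_one, q_target) + F.mse_loss(q_two, q_target)
critic_loss.backward()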
@@ -129,7 +129,7 @@ def true_train_policy(self, states, actions, rewards, next_states, dones):
pi, first_log_p, _ = self.actor_net.sample(states)
qf1_pi, qf2_pi = self.critic_net(states, pi)
min_qf_pi = torch.minimum(qf1_pi, qf2_pi)
-actor_loss = ((self.alpha * first_log_p) - min_qf_pi).mean()
+actor_loss = ((self._alpha * first_log_p) - min_qf_pi).mean()

# Update the Actor
self.actor_net_optimiser.zero_grad()
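The actor hunk above stops right after `zero_grad()`; in a standard SAC update the remaining lines are a backward pass and an optimiser step. The sketch below shows that pattern end to end with tiny stand-in networks; all sizes, names, and the fixed alpha are assumptions rather than the repository's components.

import torch
import torch.nn as nn

obs_dim, act_dim, alpha = 8, 2, 0.2
actor = nn.Linear(obs_dim, act_dim)                    # stand-in for the SAC actor
critic = nn.Linear(obs_dim + act_dim, 1)               # stand-in for one critic head
actor_optimiser = torch.optim.Adam(actor.parameters(), lr=3e-4)

states = torch.rand(4, obs_dim)
pi = torch.tanh(actor(states))                         # surrogate for actor_net.sample(states)
first_log_p = torch.zeros(4, 1)                        # placeholder log-probabilities
min_qf_pi = critic(torch.cat([states, pi], dim=-1))

actor_loss = ((alpha * first_log_p) - min_qf_pi).mean()
actor_optimiser.zero_grad()
actor_loss.backward()
actor_optimiser.step()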
@@ -225,17 +225,17 @@ def train_policy(self, experiences):
assert len(rewards.shape) == 2 and rewards.shape[1] == 1
assert len(next_states.shape) >= 2
# Step 2 train as usual
-self.true_train_policy(
+self._true_train_policy(
states=states,
actions=actions,
rewards=rewards,
next_states=next_states,
dones=dones,
)
# # # Step 3 Dyna add more data
-self.dyna_generate_and_train(next_states=next_states)
+self._dyna_generate_and_train(next_states=next_states)

-def dyna_generate_and_train(self, next_states):
+def _dyna_generate_and_train(self, next_states):
"""
Only off-policy Dyna will work.
:param next_states:
@@ -266,7 +266,7 @@ def dyna_generate_and_train(self, next_states):
# Pay attention here: these are dones in the CARES RL code!
pred_dones = torch.FloatTensor(np.zeros(pred_rs.shape)).to(self.device)
# states, actions, rewards, next_states, not_dones
-self.true_train_policy(
+self._true_train_policy(
pred_states, pred_actions, pred_rs, pred_n_states, pred_dones
)

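For orientation, `_dyna_generate_and_train` is the Dyna step: starting from real `next_states`, the policy proposes actions, the learned world model predicts rewards and successor states, `dones` are fixed to zero, and the synthetic batch is fed back through `_true_train_policy`. The sketch below mimics that loop with stand-in callables; the interfaces, the rollout horizon, and all shapes are assumptions, not the repository's API.

import torch

def policy(states):                         # stand-in for actor_net.sample(...)
    return torch.tanh(torch.randn(states.shape[0], 2))

def world_model(states, actions):           # stand-in for the ensemble world model
    pred_next = states + 0.01 * torch.randn_like(states)
    pred_rewards = torch.rand(states.shape[0], 1)
    return pred_next, pred_rewards

def dyna_generate(next_states, horizon=3):
    """Roll the model forward and collect synthetic (s, a, r, s', done) batches."""
    batches, states = [], next_states
    for _ in range(horizon):
        actions = policy(states)
        pred_next, pred_rewards = world_model(states, actions)
        pred_dones = torch.zeros_like(pred_rewards)    # synthetic data never terminates
        batches.append((states, actions, pred_rewards, pred_next, pred_dones))
        states = pred_next
    return batches

for batch in dyna_generate(torch.rand(8, 4)):
    pass   # each tuple would be handed to _true_train_policy(*batch)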
2 changes: 1 addition & 1 deletion cares_reinforcement_learning/algorithm/mbrl/__init__.py
@@ -1 +1 @@
-from .DYNA_SAC import MBRL_DYNA_SAC
+from .DYNA_SAC import DynaSAC

This file was deleted.

cares_reinforcement_learning/networks/world_models/__init__.py
@@ -0,0 +1,3 @@
+from cares_reinforcement_learning.networks.world_models.ensemble_integrated import (
+EnsembleWorldReward,
+)
cares_reinforcement_learning/networks/world_models/ensemble_integrated.py
@@ -8,11 +8,11 @@
from torch import optim
import numpy as np
from cares_reinforcement_learning.util.helpers import normalize_obs_deltas
-from cares_reinforcement_learning.networks.World_Models.simple_dynamics import (
-Simple_Dynamics,
+from cares_reinforcement_learning.networks.world_models.simple_dynamics import (
+SimpleDynamics,
)
-from cares_reinforcement_learning.networks.World_Models.simple_rewards import (
-Simple_Reward,
+from cares_reinforcement_learning.networks.world_models.simple_rewards import (
+SimpleReward,
)


@@ -27,12 +27,12 @@ class IntegratedWorldModel:
"""

def __init__(self, observation_size, num_actions, hidden_size, lr=0.001):
-self.dyna_network = Simple_Dynamics(
+self.dyna_network = SimpleDynamics(
observation_size=observation_size,
num_actions=num_actions,
hidden_size=hidden_size,
)
-self.reward_network = Simple_Reward(
+self.reward_network = SimpleReward(
observation_size=observation_size,
num_actions=num_actions,
hidden_size=hidden_size,
@@ -95,7 +95,7 @@ def train_overall(self, states, actions, next_states, next_actions, next_rewards
self.all_optimizer.step()


-class Ensemble_World_Reward:
+class EnsembleWorldReward:
"""
Ensembles the integrated dynamics and reward models. It works like a group of
experts; the predicted results can be used to estimate the uncertainty.
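`EnsembleWorldReward` groups several `IntegratedWorldModel` instances, each pairing a `SimpleDynamics` with a `SimpleReward` network, and the docstring above frames the ensemble as a group of experts whose disagreement signals uncertainty. The sketch below illustrates that idea with a plain MLP stand-in; the member count, sizes, and the variance-based disagreement measure are assumptions.

import torch
import torch.nn as nn

class TinyDynamics(nn.Module):
    """Stand-in for a SimpleDynamics-style model: (s, a) -> s'."""
    def __init__(self, obs_dim, act_dim, hidden=32):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim + act_dim, hidden), nn.ReLU(), nn.Linear(hidden, obs_dim)
        )

    def forward(self, states, actions):
        return self.net(torch.cat([states, actions], dim=-1))

obs_dim, act_dim, num_models = 4, 2, 5
ensemble = [TinyDynamics(obs_dim, act_dim) for _ in range(num_models)]

states, actions = torch.rand(8, obs_dim), torch.rand(8, act_dim)
preds = torch.stack([m(states, actions) for m in ensemble])   # [num_models, batch, obs_dim]
mean_prediction = preds.mean(dim=0)                            # consensus next state
disagreement = preds.var(dim=0).mean(dim=-1)                   # per-sample uncertainty proxy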
cares_reinforcement_learning/networks/world_models/simple_dynamics.py
@@ -9,7 +9,7 @@
)


-class Simple_Dynamics(nn.Module):
+class SimpleDynamics(nn.Module):
"""
A world model with fully connected layers. It takes current states (s) and
current actions (a), and predicts next states (s').
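The docstring above describes a fully connected model mapping (s, a) to s'. A common parameterisation for such models, and one hinted at by the `normalize_obs_deltas` helper imported in ensemble_integrated.py, is to predict a normalised state delta and add it back to the current state. The sketch below assumes that variant and uses fixed normalisation statistics purely for illustration.

import torch
import torch.nn as nn

class DeltaDynamics(nn.Module):
    """Sketch: predict a normalised state delta, then reconstruct s' = s + delta."""
    def __init__(self, obs_dim, act_dim, hidden=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim + act_dim, hidden), nn.ReLU(), nn.Linear(hidden, obs_dim)
        )
        # Normalisation statistics for the deltas; constants here for illustration.
        self.register_buffer("delta_mean", torch.zeros(obs_dim))
        self.register_buffer("delta_std", torch.ones(obs_dim))

    def forward(self, states, actions):
        norm_delta = self.net(torch.cat([states, actions], dim=-1))
        return states + norm_delta * self.delta_std + self.delta_mean

next_states = DeltaDynamics(4, 2)(torch.rand(8, 4), torch.rand(8, 2))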
cares_reinforcement_learning/networks/world_models/simple_rewards.py
@@ -4,7 +4,7 @@
from cares_reinforcement_learning.util.helpers import weight_init


-class Simple_Reward(nn.Module):
+class SimpleReward(nn.Module):
def __init__(self, observation_size, num_actions, hidden_size):
"""
Note: this reward function is limited to 0 ~ 1 for dm_control.
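The docstring above says the predicted reward is limited to 0 ~ 1 for dm_control. One straightforward way to enforce such a bound is a sigmoid on the output head; the sketch below assumes that mechanism, which is not shown in this diff.

import torch
import torch.nn as nn

class BoundedReward(nn.Module):
    """Sketch: reward head squashed into (0, 1) with a sigmoid."""
    def __init__(self, obs_dim, act_dim, hidden=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim + act_dim, hidden), nn.ReLU(), nn.Linear(hidden, 1)
        )

    def forward(self, states, actions):
        return torch.sigmoid(self.net(torch.cat([states, actions], dim=-1)))

rewards = BoundedReward(4, 2)(torch.rand(8, 4), torch.rand(8, 2))   # values in (0, 1)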
8 changes: 4 additions & 4 deletions cares_reinforcement_learning/util/network_factory.py
@@ -84,21 +84,21 @@ def create_MBRL_DYNA(observation_size, action_num, config: MBRL_DYNAConfig):
An extra world model is added.
"""
-from cares_reinforcement_learning.algorithm.mbrl import MBRL_DYNA_SAC
+from cares_reinforcement_learning.algorithm.mbrl import DynaSAC
from cares_reinforcement_learning.networks.SAC import Actor, Critic
-from cares_reinforcement_learning.networks.World_Models import Ensemble_World_Reward
+from cares_reinforcement_learning.networks.world_models import EnsembleWorldReward

actor = Actor(observation_size, action_num)
critic = Critic(observation_size, action_num)
-world_model = Ensemble_World_Reward(
+world_model = EnsembleWorldReward(
observation_size=observation_size,
num_actions=action_num,
num_models=config.num_models,
lr=config.world_model_lr,
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

-agent = MBRL_DYNA_SAC(
+agent = DynaSAC(
actor_network=actor,
critic_network=critic,
world_network=world_model,
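The factory above simply constructs each component (actor, critic, ensemble world model, device) and injects them into `DynaSAC`. The sketch below mirrors that wiring with stub classes so it can run standalone; none of the stubs are the repository's real classes, and keyword arguments beyond those visible in the diff are omitted.

import torch

class StubActor: ...
class StubCritic: ...

class StubWorldModel:
    def __init__(self, observation_size, num_actions, num_models, lr):
        self.num_models = num_models

class StubAgent:
    def __init__(self, actor_network, critic_network, world_network, device):
        self.device = device

def create_agent(observation_size, action_num, num_models=5, world_model_lr=1e-3):
    actor = StubActor()
    critic = StubCritic()
    world_model = StubWorldModel(observation_size, action_num, num_models, world_model_lr)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return StubAgent(actor_network=actor, critic_network=critic,
                     world_network=world_model, device=device)

agent = create_agent(observation_size=17, action_num=6)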
