diff --git a/cares_reinforcement_learning/algorithm/mbrl/DYNA_SAC.py b/cares_reinforcement_learning/algorithm/mbrl/DYNA_SAC.py
index d49fb9c2..6b0e56cc 100644
--- a/cares_reinforcement_learning/algorithm/mbrl/DYNA_SAC.py
+++ b/cares_reinforcement_learning/algorithm/mbrl/DYNA_SAC.py
@@ -13,7 +13,7 @@
 import torch.nn.functional as F


-class MBRL_DYNA_SAC:
+class DynaSAC:
     """
     Use the Soft Actor Critic as the Actor Critic framework.

@@ -72,7 +72,7 @@ def __init__(
         self.policy_update_freq = 1

     @property
-    def alpha(self):
+    def _alpha(self):
         """
         A variatble decide to what extend entropy shoud be valued.
         """
@@ -98,7 +98,7 @@ def select_action_from_policy(self, state, evaluation=False, noise_scale=0):
         self.actor_net.train()
         return action

-    def true_train_policy(self, states, actions, rewards, next_states, dones):
+    def _true_train_policy(self, states, actions, rewards, next_states, dones):
         """
         Train the policy with Model-Based Value Expansion. A family of MBRL.

@@ -110,7 +110,7 @@ def true_train_policy(self, states, actions, rewards, next_states, dones):
                 next_states, next_actions
             )
             target_q_values = (
-                torch.minimum(target_q_one, target_q_two) - self.alpha * next_log_pi
+                torch.minimum(target_q_one, target_q_two) - self._alpha * next_log_pi
             )
             q_target = rewards + self.gamma * (1 - dones) * target_q_values
         q_target = q_target.detach()
@@ -129,7 +129,7 @@ def true_train_policy(self, states, actions, rewards, next_states, dones):
         pi, first_log_p, _ = self.actor_net.sample(states)
         qf1_pi, qf2_pi = self.critic_net(states, pi)
         min_qf_pi = torch.minimum(qf1_pi, qf2_pi)
-        actor_loss = ((self.alpha * first_log_p) - min_qf_pi).mean()
+        actor_loss = ((self._alpha * first_log_p) - min_qf_pi).mean()

         # Update the Actor
         self.actor_net_optimiser.zero_grad()
@@ -225,7 +225,7 @@ def train_policy(self, experiences):
         assert len(rewards.shape) == 2 and rewards.shape[1] == 1
         assert len(next_states.shape) >= 2
         # Step 2 train as usual
-        self.true_train_policy(
+        self._true_train_policy(
             states=states,
             actions=actions,
             rewards=rewards,
@@ -233,9 +233,9 @@ def train_policy(self, experiences):
             dones=dones,
         )
         # # # Step 3 Dyna add more data
-        self.dyna_generate_and_train(next_states=next_states)
+        self._dyna_generate_and_train(next_states=next_states)

-    def dyna_generate_and_train(self, next_states):
+    def _dyna_generate_and_train(self, next_states):
         """
         Only off-policy Dyna will work.
         :param next_states:
@@ -266,7 +266,7 @@ def dyna_generate_and_train(self, next_states):
         # Pay attention to here! It is dones in the Cares RL Code!
         pred_dones = torch.FloatTensor(np.zeros(pred_rs.shape)).to(self.device)
         # states, actions, rewards, next_states, not_dones
-        self.true_train_policy(
+        self._true_train_policy(
             pred_states, pred_actions, pred_rs, pred_n_states, pred_dones
         )

diff --git a/cares_reinforcement_learning/algorithm/mbrl/__init__.py b/cares_reinforcement_learning/algorithm/mbrl/__init__.py
index 4725967b..fff35c2e 100644
--- a/cares_reinforcement_learning/algorithm/mbrl/__init__.py
+++ b/cares_reinforcement_learning/algorithm/mbrl/__init__.py
@@ -1 +1 @@
-from .DYNA_SAC import MBRL_DYNA_SAC
+from .DYNA_SAC import DynaSAC
\ No newline at end of file
diff --git a/cares_reinforcement_learning/networks/World_Models/__init__.py b/cares_reinforcement_learning/networks/World_Models/__init__.py
deleted file mode 100644
index 31e91b0a..00000000
--- a/cares_reinforcement_learning/networks/World_Models/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from cares_reinforcement_learning.networks.World_Models.ensemble_integrated import (
-    Ensemble_World_Reward,
-)
diff --git a/cares_reinforcement_learning/networks/world_models/__init__.py b/cares_reinforcement_learning/networks/world_models/__init__.py
new file mode 100644
index 00000000..7a4356a7
--- /dev/null
+++ b/cares_reinforcement_learning/networks/world_models/__init__.py
@@ -0,0 +1,3 @@
+from cares_reinforcement_learning.networks.world_models.ensemble_integrated import (
+    EnsembleWorldReward,
+)
diff --git a/cares_reinforcement_learning/networks/World_Models/ensemble_integrated.py b/cares_reinforcement_learning/networks/world_models/ensemble_integrated.py
similarity index 97%
rename from cares_reinforcement_learning/networks/World_Models/ensemble_integrated.py
rename to cares_reinforcement_learning/networks/world_models/ensemble_integrated.py
index bb5a2e90..4d6f2e10 100644
--- a/cares_reinforcement_learning/networks/World_Models/ensemble_integrated.py
+++ b/cares_reinforcement_learning/networks/world_models/ensemble_integrated.py
@@ -8,11 +8,11 @@
 from torch import optim
 import numpy as np
 from cares_reinforcement_learning.util.helpers import normalize_obs_deltas
-from cares_reinforcement_learning.networks.World_Models.simple_dynamics import (
-    Simple_Dynamics,
+from cares_reinforcement_learning.networks.world_models.simple_dynamics import (
+    SimpleDynamics,
 )
-from cares_reinforcement_learning.networks.World_Models.simple_rewards import (
-    Simple_Reward,
+from cares_reinforcement_learning.networks.world_models.simple_rewards import (
+    SimpleReward,
 )


@@ -27,12 +27,12 @@ class IntegratedWorldModel:
     """

     def __init__(self, observation_size, num_actions, hidden_size, lr=0.001):
-        self.dyna_network = Simple_Dynamics(
+        self.dyna_network = SimpleDynamics(
             observation_size=observation_size,
             num_actions=num_actions,
             hidden_size=hidden_size,
         )
-        self.reward_network = Simple_Reward(
+        self.reward_network = SimpleReward(
             observation_size=observation_size,
             num_actions=num_actions,
             hidden_size=hidden_size,
@@ -95,7 +95,7 @@ def train_overall(self, states, actions, next_states, next_actions, next_rewards
         self.all_optimizer.step()


-class Ensemble_World_Reward:
+class EnsembleWorldReward:
     """
     Ensemble the integrated dynamic reward models. It works like a group of
     experts. The predicted results can be used to estimate the uncertainty.
diff --git a/cares_reinforcement_learning/networks/World_Models/simple_dynamics.py b/cares_reinforcement_learning/networks/world_models/simple_dynamics.py
similarity index 98%
rename from cares_reinforcement_learning/networks/World_Models/simple_dynamics.py
rename to cares_reinforcement_learning/networks/world_models/simple_dynamics.py
index 8a05a9f9..7c27ceaf 100644
--- a/cares_reinforcement_learning/networks/World_Models/simple_dynamics.py
+++ b/cares_reinforcement_learning/networks/world_models/simple_dynamics.py
@@ -9,7 +9,7 @@
 )


-class Simple_Dynamics(nn.Module):
+class SimpleDynamics(nn.Module):
     """
     A world model with fully connected layers. It takes current states (s) and
     current actions (a), and predict next states (s').
diff --git a/cares_reinforcement_learning/networks/World_Models/simple_rewards.py b/cares_reinforcement_learning/networks/world_models/simple_rewards.py
similarity index 98%
rename from cares_reinforcement_learning/networks/World_Models/simple_rewards.py
rename to cares_reinforcement_learning/networks/world_models/simple_rewards.py
index 90310d58..1426b47d 100644
--- a/cares_reinforcement_learning/networks/World_Models/simple_rewards.py
+++ b/cares_reinforcement_learning/networks/world_models/simple_rewards.py
@@ -4,7 +4,7 @@
 from cares_reinforcement_learning.util.helpers import weight_init


-class Simple_Reward(nn.Module):
+class SimpleReward(nn.Module):
     def __init__(self, observation_size, num_actions, hidden_size):
         """
         Note, This reward function is limited to 0 ~ 1 for dm_control.
diff --git a/cares_reinforcement_learning/util/network_factory.py b/cares_reinforcement_learning/util/network_factory.py
index 31942cc0..3630ab04 100644
--- a/cares_reinforcement_learning/util/network_factory.py
+++ b/cares_reinforcement_learning/util/network_factory.py
@@ -84,13 +84,13 @@ def create_MBRL_DYNA(observation_size, action_num, config: MBRL_DYNAConfig):
     An extra world model is added.
     """

-    from cares_reinforcement_learning.algorithm.mbrl import MBRL_DYNA_SAC
+    from cares_reinforcement_learning.algorithm.mbrl import DynaSAC
     from cares_reinforcement_learning.networks.SAC import Actor, Critic
-    from cares_reinforcement_learning.networks.World_Models import Ensemble_World_Reward
+    from cares_reinforcement_learning.networks.world_models import EnsembleWorldReward

     actor = Actor(observation_size, action_num)
     critic = Critic(observation_size, action_num)
-    world_model = Ensemble_World_Reward(
+    world_model = EnsembleWorldReward(
         observation_size=observation_size,
         num_actions=action_num,
         num_models=config.num_models,
@@ -98,7 +98,7 @@ def create_MBRL_DYNA(observation_size, action_num, config: MBRL_DYNAConfig):
     )

     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    agent = MBRL_DYNA_SAC(
+    agent = DynaSAC(
         actor_network=actor,
         critic_network=critic,
         world_network=world_model,
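The renames above change the public import surface of the package. The snippet below is a minimal, hypothetical downstream sketch summarising the old-to-new name mapping introduced by these hunks; it uses only symbols visible in the diff, and any constructor arguments not shown in the truncated hunks are deliberately omitted here.

# Updated imports after this rename (old names kept in comments for reference).
# from cares_reinforcement_learning.algorithm.mbrl import MBRL_DYNA_SAC          # old
from cares_reinforcement_learning.algorithm.mbrl import DynaSAC                   # new
# from cares_reinforcement_learning.networks.World_Models import Ensemble_World_Reward   # old
from cares_reinforcement_learning.networks.world_models import EnsembleWorldReward        # new

# Rename map introduced by this diff:
#   MBRL_DYNA_SAC                    -> DynaSAC
#   Ensemble_World_Reward            -> EnsembleWorldReward
#   Simple_Dynamics                  -> SimpleDynamics
#   Simple_Reward                    -> SimpleReward
#   networks/World_Models/           -> networks/world_models/
#   alpha, true_train_policy, dyna_generate_and_train
#       -> _alpha, _true_train_policy, _dyna_generate_and_train (underscore-prefixed internals)

The underscore prefixes signal that _alpha, _true_train_policy, and _dyna_generate_and_train are internal to DynaSAC; in these hunks only select_action_from_policy and train_policy appear to remain part of the agent's public interface.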