diff --git a/cares_reinforcement_learning/algorithm/policy/NaSATD3.py b/cares_reinforcement_learning/algorithm/policy/NaSATD3.py
index 527aa511..b439aff0 100644
--- a/cares_reinforcement_learning/algorithm/policy/NaSATD3.py
+++ b/cares_reinforcement_learning/algorithm/policy/NaSATD3.py
@@ -84,14 +84,15 @@ def __init__(
         ]
 
     def select_action_from_policy(
-        self, state: np.ndarray, evaluation: bool = False, noise_scale: float = 0.1
+        self,
+        state: dict[str, np.ndarray],
+        evaluation: bool = False,
+        noise_scale: float = 0.1,
     ) -> np.ndarray:
         self.actor.eval()
         self.autoencoder.eval()
         with torch.no_grad():
-            state_tensor = torch.FloatTensor(state).to(self.device)
-            state_tensor = state_tensor.unsqueeze(0)
-            state_tensor = state_tensor / 255
+            state_tensor = hlp.image_state_dict_to_tensor(state, self.device)
 
             action = self.actor(state_tensor)
             action = action.cpu().data.numpy().flatten()
@@ -108,7 +109,7 @@ def select_action_from_policy(
 
     def _update_critic(
         self,
-        states: torch.Tensor,
+        states: dict[str, torch.Tensor],
         actions: torch.Tensor,
         rewards: torch.Tensor,
         next_states: torch.Tensor,
@@ -145,7 +146,7 @@ def _update_autoencoder(self, states: torch.Tensor) -> float:
         ae_loss = self.autoencoder.update_autoencoder(states)
         return ae_loss.item()
 
-    def _update_actor(self, states: torch.Tensor) -> float:
+    def _update_actor(self, states: dict[str, torch.Tensor]) -> float:
         actor_q_one, actor_q_two = self.critic(
             states, self.actor(states, detach_encoder=True), detach_encoder=True
         )
@@ -174,16 +175,19 @@ def _get_latent_state(
         return latent_state
 
     def _update_predictive_model(
-        self, states: np.ndarray, actions: np.ndarray, next_states: np.ndarray
+        self,
+        states: dict[str, torch.Tensor],
+        actions: np.ndarray,
+        next_states: dict[str, torch.Tensor],
     ) -> list[float]:
 
         with torch.no_grad():
             latent_state = self._get_latent_state(
-                states, detach_output=True, sample_latent=True
+                states["image"], detach_output=True, sample_latent=True
             )
 
             latent_next_state = self._get_latent_state(
-                next_states, detach_output=True, sample_latent=True
+                next_states["image"], detach_output=True, sample_latent=True
             )
 
         pred_losses = []
@@ -218,17 +222,14 @@ def train_policy(self, memory: MemoryBuffer, batch_size: int) -> dict[str, Any]:
 
         batch_size = len(states)
 
-        # Convert into tensor
-        states = torch.FloatTensor(np.asarray(states)).to(self.device)
+        states = hlp.image_states_dict_to_tensor(states, self.device)
+
         actions = torch.FloatTensor(np.asarray(actions)).to(self.device)
         rewards = torch.FloatTensor(np.asarray(rewards)).to(self.device)
-        next_states = torch.FloatTensor(np.asarray(next_states)).to(self.device)
-        dones = torch.LongTensor(np.asarray(dones)).to(self.device)
 
-        # Normalise states and next_states
-        # This because the states are [0-255] and the predictions are [0-1]
-        states /= 255
-        next_states /= 255
+        next_states = hlp.image_states_dict_to_tensor(next_states, self.device)
+
+        dones = torch.LongTensor(np.asarray(dones)).to(self.device)
 
         # Reshape to batch_size
         rewards = rewards.unsqueeze(0).reshape(batch_size, 1)
@@ -245,7 +246,7 @@ def train_policy(self, memory: MemoryBuffer, batch_size: int) -> dict[str, Any]:
         info["critic_loss_total"] = critic_loss_total
 
         # Update Autoencoder
-        ae_loss = self._update_autoencoder(states)
+        ae_loss = self._update_autoencoder(states["image"])
         info["ae_loss"] = ae_loss
 
         if self.learn_counter % self.policy_update_freq == 0:
@@ -322,26 +323,40 @@ def _get_novelty_rate(self, state_tensor_img: torch.Tensor) -> float:
         return novelty_rate
 
     def get_intrinsic_reward(
-        self, state: np.ndarray, action: np.ndarray, next_state: np.ndarray
+        self,
+        state: dict[str, np.ndarray],
+        action: np.ndarray,
+        next_state: dict[str, np.ndarray],
     ) -> float:
         with torch.no_grad():
-            # Normalise states and next_states
-            # This because the states are [0-255] and the predictions are [0-1]
-            state_tensor = torch.FloatTensor(state).to(self.device)
-            state_tensor = state_tensor.unsqueeze(0)
-            state_tensor = state_tensor / 255
+            vector_tensor = torch.FloatTensor(state["vector"])
+            vector_tensor = vector_tensor.unsqueeze(0).to(self.device)
+
+            image_tensor = torch.FloatTensor(state["image"])
+            image_tensor = image_tensor.unsqueeze(0).to(self.device)
+            image_tensor = image_tensor / 255
+
+            state_tensor = {"image": image_tensor, "vector": vector_tensor}
+
+            next_vector_tensor = torch.FloatTensor(next_state["vector"])
+            next_vector_tensor = next_vector_tensor.unsqueeze(0).to(self.device)
+
+            next_image_tensor = torch.FloatTensor(next_state["image"])
+            next_image_tensor = next_image_tensor.unsqueeze(0).to(self.device)
+            next_image_tensor = next_image_tensor / 255
 
-            next_state_tensor = torch.FloatTensor(next_state).to(self.device)
-            next_state_tensor = next_state_tensor.unsqueeze(0)
-            next_state_tensor = next_state_tensor / 255
+            next_state_tensor = {
+                "image": next_image_tensor,
+                "vector": next_vector_tensor,
+            }
 
             action_tensor = torch.FloatTensor(action).to(self.device)
             action_tensor = action_tensor.unsqueeze(0)
 
             surprise_rate = self._get_surprise_rate(
-                state_tensor, action_tensor, next_state_tensor
+                state_tensor["image"], action_tensor, next_state_tensor["image"]
             )
-            novelty_rate = self._get_novelty_rate(state_tensor)
+            novelty_rate = self._get_novelty_rate(state_tensor["image"])
 
         # TODO make these parameters - i.e. Tony's work
         a = 1.0
diff --git a/cares_reinforcement_learning/algorithm/policy/SACAE.py b/cares_reinforcement_learning/algorithm/policy/SACAE.py
index a1081ed8..52a70d03 100644
--- a/cares_reinforcement_learning/algorithm/policy/SACAE.py
+++ b/cares_reinforcement_learning/algorithm/policy/SACAE.py
@@ -15,7 +15,6 @@
 import torch.nn.functional as F
 
 import cares_reinforcement_learning.util.helpers as hlp
-from cares_reinforcement_learning.encoders.configurations import VanillaAEConfig
 from cares_reinforcement_learning.encoders.losses import AELoss
 from cares_reinforcement_learning.memory import MemoryBuffer
 from cares_reinforcement_learning.util.configurations import SACAEConfig
@@ -97,14 +96,15 @@ def __init__(
 
     # pylint: disable-next=unused-argument
     def select_action_from_policy(
-        self, state: np.ndarray, evaluation: bool = False, noise_scale: float = 0
+        self,
+        state: dict[str, np.ndarray],
+        evaluation: bool = False,
+        noise_scale: float = 0,
    ) -> np.ndarray:
         # note that when evaluating this algorithm we need to select mu as action
         self.actor_net.eval()
         with torch.no_grad():
-            state_tensor = torch.FloatTensor(state)
-            state_tensor = state_tensor.unsqueeze(0).to(self.device)
-            state_tensor = state_tensor / 255
+            state_tensor = hlp.image_state_dict_to_tensor(state, self.device)
 
             if evaluation:
                 (_, _, action) = self.actor_net(state_tensor)
@@ -120,12 +120,13 @@ def alpha(self) -> torch.Tensor:
 
     def _update_critic(
         self,
-        states: torch.Tensor,
+        states: dict[str, torch.Tensor],
         actions: torch.Tensor,
         rewards: torch.Tensor,
         next_states: torch.Tensor,
         dones: torch.Tensor,
     ) -> tuple[float, float, float]:
+
         with torch.no_grad():
             next_actions, next_log_pi, _ = self.actor_net(next_states)
 
@@ -153,7 +154,9 @@ def _update_critic(
 
         return critic_loss_one.item(), critic_loss_two.item(), critic_loss_total.item()
 
-    def _update_actor_alpha(self, states: torch.Tensor) -> tuple[float, float]:
+    def _update_actor_alpha(
+        self, states: dict[str, torch.Tensor]
+    ) -> tuple[float, float]:
         pi, log_pi, _ = self.actor_net(states, detach_encoder=True)
 
         qf1_pi, qf2_pi = self.critic_net(states, pi, detach_encoder=True)
@@ -199,27 +202,24 @@ def train_policy(self, memory: MemoryBuffer, batch_size: int) -> dict[str, Any]:
 
         batch_size = len(states)
 
-        # Convert into tensor
-        states = torch.FloatTensor(np.asarray(states)).to(self.device)
+        states = hlp.image_states_dict_to_tensor(states, self.device)
+
         actions = torch.FloatTensor(np.asarray(actions)).to(self.device)
         rewards = torch.FloatTensor(np.asarray(rewards)).to(self.device)
-        next_states = torch.FloatTensor(np.asarray(next_states)).to(self.device)
+
+        next_states = hlp.image_states_dict_to_tensor(next_states, self.device)
+
         dones = torch.LongTensor(np.asarray(dones)).to(self.device)
 
         # Reshape to batch_size x whatever
         rewards = rewards.unsqueeze(0).reshape(batch_size, 1)
         dones = dones.unsqueeze(0).reshape(batch_size, 1)
 
-        # Normalise states and next_states
-        # This because the states are [0-255] and the predictions are [0-1]
-        states_normalised = states / 255
-        next_states_normalised = next_states / 255
-
         info = {}
 
         # Update the Critic
         critic_loss_one, critic_loss_two, critic_loss_total = self._update_critic(
-            states_normalised, actions, rewards, next_states_normalised, dones
+            states, actions, rewards, next_states, dones
         )
         info["critic_loss_one"] = critic_loss_one
         info["critic_loss_two"] = critic_loss_two
@@ -227,7 +227,7 @@ def train_policy(self, memory: MemoryBuffer, batch_size: int) -> dict[str, Any]:
 
         # Update the Actor
         if self.learn_counter % self.policy_update_freq == 0:
-            actor_loss, alpha_loss = self._update_actor_alpha(states_normalised)
+            actor_loss, alpha_loss = self._update_actor_alpha(states)
             info["actor_loss"] = actor_loss
             info["alpha_loss"] = alpha_loss
             info["alpha"] = self.alpha.item()
@@ -247,7 +247,7 @@ def train_policy(self, memory: MemoryBuffer, batch_size: int) -> dict[str, Any]:
             )
 
         if self.learn_counter % self.decoder_update_freq == 0:
-            ae_loss = self._update_autoencoder(states_normalised)
+            ae_loss = self._update_autoencoder(states["image"])
             info["ae_loss"] = ae_loss
 
         return info
diff --git a/cares_reinforcement_learning/algorithm/policy/TD3AE.py b/cares_reinforcement_learning/algorithm/policy/TD3AE.py
index 10d5b544..5511d3a2 100644
--- a/cares_reinforcement_learning/algorithm/policy/TD3AE.py
+++ b/cares_reinforcement_learning/algorithm/policy/TD3AE.py
@@ -13,7 +13,6 @@
 import torch.nn.functional as F
 
 import cares_reinforcement_learning.util.helpers as hlp
-from cares_reinforcement_learning.encoders.configurations import VanillaAEConfig
 from cares_reinforcement_learning.encoders.losses import AELoss
 from cares_reinforcement_learning.memory import MemoryBuffer
 from cares_reinforcement_learning.util.configurations import TD3AEConfig
@@ -78,13 +77,14 @@ def __init__(
         )
 
     def select_action_from_policy(
-        self, state: np.ndarray, evaluation: bool = False, noise_scale: float = 0.1
+        self,
+        state: dict[str, np.ndarray],
+        evaluation: bool = False,
+        noise_scale: float = 0.1,
     ) -> np.ndarray:
         self.actor_net.eval()
         with torch.no_grad():
-            state_tensor = torch.FloatTensor(state).to(self.device)
-            state_tensor = state_tensor.unsqueeze(0)
-            state_tensor = state_tensor / 255
+            state_tensor = hlp.image_state_dict_to_tensor(state, self.device)
 
             action = self.actor_net(state_tensor)
             action = action.cpu().data.numpy().flatten()
@@ -98,7 +98,7 @@ def select_action_from_policy(
 
     def _update_critic(
         self,
-        states: torch.Tensor,
+        states: dict[str, torch.Tensor],
         actions: torch.Tensor,
         rewards: torch.Tensor,
         next_states: torch.Tensor,
@@ -132,7 +132,7 @@ def _update_critic(
 
         return critic_loss_one.item(), critic_loss_two.item(), critic_loss_total.item()
 
-    def _update_actor(self, states: torch.Tensor) -> float:
+    def _update_actor(self, states: dict[str, torch.Tensor]) -> float:
         actions = self.actor_net(states, detach_encoder=True)
         actor_q_values, _ = self.critic_net(states, actions, detach_encoder=True)
         actor_loss = -actor_q_values.mean()
@@ -169,26 +169,23 @@ def train_policy(self, memory: MemoryBuffer, batch_size: int) -> dict[str, Any]:
 
         batch_size = len(states)
 
-        # Convert into tensor
-        states = torch.FloatTensor(np.asarray(states)).to(self.device)
+        states = hlp.image_states_dict_to_tensor(states, self.device)
+
         actions = torch.FloatTensor(np.asarray(actions)).to(self.device)
         rewards = torch.FloatTensor(np.asarray(rewards)).to(self.device)
-        next_states = torch.FloatTensor(np.asarray(next_states)).to(self.device)
+
+        next_states = hlp.image_states_dict_to_tensor(next_states, self.device)
+
         dones = torch.LongTensor(np.asarray(dones)).to(self.device)
 
         # Reshape to batch_size
         rewards = rewards.unsqueeze(0).reshape(batch_size, 1)
         dones = dones.unsqueeze(0).reshape(batch_size, 1)
 
-        # Normalise states and next_states
-        # This because the states are [0-255] and the predictions are [0-1]
-        states_normalised = states / 255
-        next_states_normalised = next_states / 255
-
         info = {}
 
         critic_loss_one, critic_loss_two, critic_loss_total = self._update_critic(
-            states_normalised, actions, rewards, next_states_normalised, dones
+            states, actions, rewards, next_states, dones
         )
         info["critic_loss_one"] = critic_loss_one
         info["critic_loss_two"] = critic_loss_two
@@ -196,7 +193,7 @@ def train_policy(self, memory: MemoryBuffer, batch_size: int) -> dict[str, Any]:
 
         if self.learn_counter % self.policy_update_freq == 0:
             # Update Actor
-            actor_loss = self._update_actor(states_normalised)
+            actor_loss = self._update_actor(states)
             info["actor_loss"] = actor_loss
 
             # Update target network params
@@ -222,7 +219,7 @@ def train_policy(self, memory: MemoryBuffer, batch_size: int) -> dict[str, Any]:
             )
 
         if self.learn_counter % self.decoder_update_freq == 0:
-            ae_loss = self._update_autoencoder(states_normalised)
+            ae_loss = self._update_autoencoder(states["image"])
             info["ae_loss"] = ae_loss
 
         return info
diff --git a/cares_reinforcement_learning/encoders/configurations.py b/cares_reinforcement_learning/encoders/configurations.py
index cb12a8f4..e7a56f82 100644
--- a/cares_reinforcement_learning/encoders/configurations.py
+++ b/cares_reinforcement_learning/encoders/configurations.py
@@ -50,29 +50,6 @@ class VanillaAEConfig(AEConfig):
     latent_lambda: float = 1e-6
 
 
-# sqVAE = parser.add_argument_group('SQ-VAE specific parameters')
-# sqVAE.add_argument('--dim_z', type=int, default=16)
-# sqVAE.add_argument('--size_dict', type=int, default=512)
-# sqVAE.add_argument('--param_var_q', type=str, default=ParamVarQ.GAUSSIAN_1.value,
-#                    choices=[pvq.value for pvq in ParamVarQ])
-# sqVAE.add_argument('--num_rb', type=int, default=6)
-# sqVAE.add_argument('--flg_arelbo', type=bool, default=True)
-# sqVAE.add_argument('--log_param_q_init', type=float, default=0.0)
-# sqVAE.add_argument('--temperature_init', type=float, default=1.0)
-
-# class SQVAEConfig(AEConfig):
-#     """
-#     Configuration class for the sqvae autoencoder.
-
-#     Attributes:
-
-#     """
-
-#     type: str = "sqvae"
-#     flg_arelbo: bool = Field(description="Flag to use arelbo loss function")
-#     loss_latent: str = Field(description="")
-
-
 class BurgessConfig(AEConfig):
     """
     Configuration class for the Burgess autoencoder.
diff --git a/cares_reinforcement_learning/networks/NaSATD3/actor.py b/cares_reinforcement_learning/networks/NaSATD3/actor.py
index 4f9ba098..ac34a17f 100644
--- a/cares_reinforcement_learning/networks/NaSATD3/actor.py
+++ b/cares_reinforcement_learning/networks/NaSATD3/actor.py
@@ -11,6 +11,7 @@ class Actor(nn.Module):
 
     def __init__(
         self,
+        vector_observation_size: int,
         num_actions: int,
         autoencoder: Autoencoder,
         hidden_size: list[int] = None,
@@ -22,9 +23,13 @@ def __init__(
         self.num_actions = num_actions
         self.autoencoder = autoencoder
         self.hidden_size = hidden_size
+        self.vector_observation_size = vector_observation_size
 
         self.act_net = nn.Sequential(
-            nn.Linear(self.autoencoder.latent_dim, self.hidden_size[0]),
+            nn.Linear(
+                self.autoencoder.latent_dim + self.vector_observation_size,
+                self.hidden_size[0],
+            ),
             nn.ReLU(),
             nn.Linear(self.hidden_size[0], self.hidden_size[1]),
             nn.ReLU(),
@@ -34,16 +39,21 @@ def __init__(
         self.apply(hlp.weight_init)
 
     def forward(
-        self, state: torch.Tensor, detach_encoder: bool = False
+        self, state: dict[str, torch.Tensor], detach_encoder: bool = False
     ) -> torch.Tensor:
         # NaSATD3 detatches the encoder at the output
         if self.autoencoder.ae_type == Autoencoders.BURGESS:
             # take the mean value for stability
             z_vector, _, _ = self.autoencoder.encoder(
-                state, detach_output=detach_encoder
+                state["image"], detach_output=detach_encoder
             )
         else:
-            z_vector = self.autoencoder.encoder(state, detach_output=detach_encoder)
+            z_vector = self.autoencoder.encoder(
+                state["image"], detach_output=detach_encoder
+            )
+
+        actor_input = z_vector
+        if self.vector_observation_size > 0:
+            actor_input = torch.cat([state["vector"], actor_input], dim=1)
 
-        output = self.act_net(z_vector)
-        return output
+        return self.act_net(actor_input)
diff --git a/cares_reinforcement_learning/networks/NaSATD3/critic.py b/cares_reinforcement_learning/networks/NaSATD3/critic.py
index ffd09d13..d4934999 100644
--- a/cares_reinforcement_learning/networks/NaSATD3/critic.py
+++ b/cares_reinforcement_learning/networks/NaSATD3/critic.py
@@ -11,6 +11,7 @@ class Critic(nn.Module):
 
     def __init__(
         self,
+        vector_observation_size: int,
         num_actions: int,
         autoencoder: Autoencoder,
         hidden_size: list[int] = None,
@@ -21,10 +22,16 @@ def __init__(
 
         self.autoencoder = autoencoder
         self.hidden_size = hidden_size
+        self.vector_observation_size = vector_observation_size
 
         # pylint: disable-next=invalid-name
         self.Q1 = nn.Sequential(
-            nn.Linear(self.autoencoder.latent_dim + num_actions, self.hidden_size[0]),
+            nn.Linear(
+                self.autoencoder.latent_dim
+                + num_actions
+                + self.vector_observation_size,
+                self.hidden_size[0],
+            ),
             nn.ReLU(),
             nn.Linear(self.hidden_size[0], self.hidden_size[1]),
             nn.ReLU(),
@@ -33,7 +40,12 @@ def __init__(
 
         # pylint: disable-next=invalid-name
         self.Q2 = nn.Sequential(
-            nn.Linear(self.autoencoder.latent_dim + num_actions, self.hidden_size[0]),
+            nn.Linear(
+                self.autoencoder.latent_dim
+                + num_actions
+                + self.vector_observation_size,
+                self.hidden_size[0],
+            ),
             nn.ReLU(),
             nn.Linear(self.hidden_size[0], self.hidden_size[1]),
             nn.ReLU(),
@@ -43,18 +55,27 @@ def __init__(
         self.apply(hlp.weight_init)
 
     def forward(
-        self, state: torch.Tensor, action: torch.Tensor, detach_encoder: bool = False
+        self,
+        state: dict[str, torch.Tensor],
+        action: torch.Tensor,
+        detach_encoder: bool = False,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         # NaSATD3 detatches the encoder at the output
         if self.autoencoder.ae_type == Autoencoders.BURGESS:
             # take the mean value for stability
             z_vector, _, _ = self.autoencoder.encoder(
-                state, detach_output=detach_encoder
+                state["image"], detach_output=detach_encoder
             )
         else:
-            z_vector = self.autoencoder.encoder(state, detach_output=detach_encoder)
+            z_vector = self.autoencoder.encoder(
+                state["image"], detach_output=detach_encoder
+            )
+
+        critic_input = z_vector
+        if self.vector_observation_size > 0:
+            critic_input = torch.cat([state["vector"], critic_input], dim=1)
 
-        obs_action = torch.cat([z_vector, action], dim=1)
+        obs_action = torch.cat([critic_input, action], dim=1)
         q1 = self.Q1(obs_action)
         q2 = self.Q2(obs_action)
         return q1, q2
diff --git a/cares_reinforcement_learning/networks/SACAE/actor.py b/cares_reinforcement_learning/networks/SACAE/actor.py
index a9180442..46928189 100644
--- a/cares_reinforcement_learning/networks/SACAE/actor.py
+++ b/cares_reinforcement_learning/networks/SACAE/actor.py
@@ -8,6 +8,7 @@ class Actor(SACActor):
 
     def __init__(
         self,
+        vector_observation_size: int,
         encoder: Encoder,
         num_actions: int,
         hidden_size: list[int] = None,
@@ -18,15 +19,27 @@ def __init__(
         if log_std_bounds is None:
             log_std_bounds = [-10, 2]
 
-        super().__init__(encoder.latent_dim, num_actions, hidden_size, log_std_bounds)
+        super().__init__(
+            encoder.latent_dim + vector_observation_size,
+            num_actions,
+            hidden_size,
+            log_std_bounds,
+        )
 
         self.encoder = encoder
 
+        self.vector_observation_size = vector_observation_size
+
         self.apply(hlp.weight_init)
 
     def forward(
-        self, state: torch.Tensor, detach_encoder: bool = False
+        self, state: dict[str, torch.Tensor], detach_encoder: bool = False
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         # Detach at the CNN layer to prevent backpropagation through the encoder
-        state_latent = self.encoder(state, detach_cnn=detach_encoder)
-        return super().forward(state_latent)
+        state_latent = self.encoder(state["image"], detach_cnn=detach_encoder)
+
+        actor_input = state_latent
+        if self.vector_observation_size > 0:
+            actor_input = torch.cat([state["vector"], actor_input], dim=1)
+
+        return super().forward(actor_input)
diff --git a/cares_reinforcement_learning/networks/SACAE/critic.py b/cares_reinforcement_learning/networks/SACAE/critic.py
index f42358b3..5d9e59c4 100644
--- a/cares_reinforcement_learning/networks/SACAE/critic.py
+++ b/cares_reinforcement_learning/networks/SACAE/critic.py
@@ -8,6 +8,7 @@ class Critic(SACCritic):
 
     def __init__(
         self,
+        vector_observation_size: int,
         encoder: Encoder,
         num_actions: int,
         hidden_size: list[int] = None,
@@ -15,15 +16,27 @@ def __init__(
         if hidden_size is None:
             hidden_size = [1024, 1024]
 
-        super().__init__(encoder.latent_dim, num_actions, hidden_size)
+        super().__init__(
+            encoder.latent_dim + vector_observation_size, num_actions, hidden_size
+        )
+
+        self.vector_observation_size = vector_observation_size
 
         self.encoder = encoder
 
         self.apply(hlp.weight_init)
 
     def forward(
-        self, state: torch.Tensor, action: torch.Tensor, detach_encoder: bool = False
+        self,
+        state: dict[str, torch.Tensor],
+        action: torch.Tensor,
+        detach_encoder: bool = False,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         # Detach at the CNN layer to prevent backpropagation through the encoder
-        state_latent = self.encoder(state, detach_cnn=detach_encoder)
-        return super().forward(state_latent, action)
+        state_latent = self.encoder(state["image"], detach_cnn=detach_encoder)
+
+        critic_input = state_latent
+        if self.vector_observation_size > 0:
+            critic_input = torch.cat([state["vector"], critic_input], dim=1)
+
+        return super().forward(critic_input, action)
diff --git a/cares_reinforcement_learning/networks/TD3AE/actor.py b/cares_reinforcement_learning/networks/TD3AE/actor.py
index 24c9bb11..75ddf3af 100644
--- a/cares_reinforcement_learning/networks/TD3AE/actor.py
+++ b/cares_reinforcement_learning/networks/TD3AE/actor.py
@@ -8,6 +8,7 @@ class Actor(TD3Actor):
 
     def __init__(
         self,
+        vector_observation_size: int,
         encoder: Encoder,
         num_actions: int,
         hidden_size: list[int] = None,
@@ -15,15 +16,24 @@ def __init__(
         if hidden_size is None:
             hidden_size = [1024, 1024]
 
-        super().__init__(encoder.latent_dim, num_actions, hidden_size)
+        super().__init__(
+            encoder.latent_dim + vector_observation_size, num_actions, hidden_size
+        )
 
         self.encoder = encoder
 
         self.apply(hlp.weight_init)
 
+        self.vector_observation_size = vector_observation_size
+
     def forward(
-        self, state: torch.Tensor, detach_encoder: bool = False
+        self, state: dict[str, torch.Tensor], detach_encoder: bool = False
     ) -> torch.Tensor:
         # Detach at the CNN layer to prevent backpropagation through the encoder
-        state_latent = self.encoder(state, detach_cnn=detach_encoder)
-        return super().forward(state_latent)
+        state_latent = self.encoder(state["image"], detach_cnn=detach_encoder)
+
+        actor_input = state_latent
+        if self.vector_observation_size > 0:
+            actor_input = torch.cat([state["vector"], actor_input], dim=1)
+
+        return super().forward(actor_input)
diff --git a/cares_reinforcement_learning/networks/TD3AE/critic.py b/cares_reinforcement_learning/networks/TD3AE/critic.py
index f9e62ca2..7401def9 100644
--- a/cares_reinforcement_learning/networks/TD3AE/critic.py
+++ b/cares_reinforcement_learning/networks/TD3AE/critic.py
@@ -8,6 +8,7 @@ class Critic(TD3Critic):
 
     def __init__(
         self,
+        vector_observation_size: int,
         encoder: Encoder,
         num_actions: int,
         hidden_size: list[int] = None,
@@ -15,15 +16,27 @@ def __init__(
         if hidden_size is None:
             hidden_size = [1024, 1024]
 
-        super().__init__(encoder.latent_dim, num_actions, hidden_size)
+        super().__init__(
+            encoder.latent_dim + vector_observation_size, num_actions, hidden_size
+        )
+
+        self.vector_observation_size = vector_observation_size
 
         self.encoder = encoder
 
         self.apply(hlp.weight_init)
 
     def forward(
-        self, state: torch.Tensor, action: torch.Tensor, detach_encoder: bool = False
+        self,
+        state: dict[str, torch.Tensor],
+        action: torch.Tensor,
+        detach_encoder: bool = False,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         # Detach at the CNN layer to prevent backpropagation through the encoder
-        state_latent = self.encoder(state, detach_cnn=detach_encoder)
-        return super().forward(state_latent, action)
+        state_latent = self.encoder(state["image"], detach_cnn=detach_encoder)
+
+        critic_input = state_latent
+        if self.vector_observation_size > 0:
+            critic_input = torch.cat([state["vector"], critic_input], dim=1)
+
+        return super().forward(critic_input, action)
diff --git a/cares_reinforcement_learning/util/configurations.py b/cares_reinforcement_learning/util/configurations.py
index 3c23091b..5a537684 100644
--- a/cares_reinforcement_learning/util/configurations.py
+++ b/cares_reinforcement_learning/util/configurations.py
@@ -4,8 +4,8 @@
 
 from cares_reinforcement_learning.encoders.configurations import (
     AEConfig,
-    VanillaAEConfig,
     VAEConfig,
+    VanillaAEConfig,
 )
 
 # NOTE: If a parameter is a list then don't wrap with Optional leave as implicit optional - List[type] = default
@@ -40,7 +40,7 @@ class TrainingConfig(SubscriptableClass):
 
 
 class AlgorithmConfig(SubscriptableClass):
-    """f
+    """
     Configuration class for the algorithm.
 
     These attributes are common to all algorithms. They can be overridden by the specific algorithm configuration.
@@ -81,9 +81,6 @@ class AlgorithmConfig(SubscriptableClass):
 
     hidden_size: List[int] = None
 
-    # Determines how much prioritization is used, α = 0 corresponding to the uniform case
-    # per_alpha
-
 
 class DQNConfig(AlgorithmConfig):
     algorithm: str = Field("DQN", Literal=True)
@@ -159,6 +156,8 @@ class TD3AEConfig(AlgorithmConfig):
     encoder_tau: Optional[float] = 0.05
     decoder_update_freq: Optional[int] = 1
 
+    vector_observation: Optional[int] = 0
+
     autoencoder_config: Optional[VanillaAEConfig] = VanillaAEConfig(
         latent_dim=50,
         num_layers=4,
@@ -202,6 +201,8 @@ class SACAEConfig(AlgorithmConfig):
     encoder_tau: Optional[float] = 0.05
     decoder_update_freq: Optional[int] = 1
 
+    vector_observation: Optional[int] = 0
+
     autoencoder_config: Optional[VanillaAEConfig] = VanillaAEConfig(
         latent_dim=50,
         num_layers=4,
@@ -267,6 +268,8 @@ class NaSATD3Config(AlgorithmConfig):
 
     intrinsic_on: Optional[int] = 1
 
+    vector_observation: Optional[int] = 0
+
     autoencoder_config: Optional[AEConfig] = VanillaAEConfig(
         latent_dim=200,
         num_layers=4,
diff --git a/cares_reinforcement_learning/util/helpers.py b/cares_reinforcement_learning/util/helpers.py
index 3b40055b..81881673 100644
--- a/cares_reinforcement_learning/util/helpers.py
+++ b/cares_reinforcement_learning/util/helpers.py
@@ -19,6 +19,36 @@ def get_device() -> torch.device:
     return device
 
 
+def image_state_dict_to_tensor(
+    state: dict[str, np.ndarray], device: str
+) -> dict[str, torch.Tensor]:
+    vector_tensor = torch.FloatTensor(state["vector"])
+    vector_tensor = vector_tensor.unsqueeze(0).to(device)
+
+    image_tensor = torch.FloatTensor(state["image"])
+    image_tensor = image_tensor.unsqueeze(0).to(device)
+    image_tensor = image_tensor / 255
+
+    return {"image": image_tensor, "vector": vector_tensor}
+
+
+def image_states_dict_to_tensor(
+    states: list[dict[str, np.ndarray]], device: str
+) -> dict[str, torch.Tensor]:
+    states_images = [state["image"] for state in states]
+    states_vector = [state["vector"] for state in states]
+
+    # Convert into tensor
+    states_images = torch.FloatTensor(np.asarray(states_images)).to(device)
+    states_vector = torch.FloatTensor(np.asarray(states_vector)).to(device)
+
+    # Normalise states and next_states - image portion
+    # This because the states are [0-255] and the predictions are [0-1]
+    states_images = states_images / 255
+
+    return {"image": states_images, "vector": states_vector}
+
+
 def create_path_from_format_string(
     format_str: str,
     algorithm: str,
diff --git a/cares_reinforcement_learning/util/network_factory.py b/cares_reinforcement_learning/util/network_factory.py
index 6c42aaba..a5371801 100644
--- a/cares_reinforcement_learning/util/network_factory.py
+++ b/cares_reinforcement_learning/util/network_factory.py
@@ -141,19 +141,29 @@ def create_SACAE(observation_size, action_num, config: AlgorithmConfig):
 
     ae_factory = AEFactory()
     autoencoder = ae_factory.create_autoencoder(
-        observation_size=observation_size, config=config.autoencoder_config
+        observation_size=observation_size["image"], config=config.autoencoder_config
     )
 
     actor_encoder = copy.deepcopy(autoencoder.encoder)
     critic_encoder = copy.deepcopy(autoencoder.encoder)
 
+    vector_observation_size = (
+        observation_size["vector"] if config.vector_observation else 0
+    )
+
     actor = Actor(
+        vector_observation_size,
         actor_encoder,
         action_num,
         hidden_size=config.hidden_size,
         log_std_bounds=config.log_std_bounds,
     )
-    critic = Critic(critic_encoder, action_num, hidden_size=config.hidden_size)
+    critic = Critic(
+        vector_observation_size,
+        critic_encoder,
+        action_num,
+        hidden_size=config.hidden_size,
+    )
 
     device = hlp.get_device()
     agent = SACAE(
@@ -224,14 +234,28 @@ def create_TD3AE(observation_size, action_num, config: AlgorithmConfig):
 
     ae_factory = AEFactory()
     autoencoder = ae_factory.create_autoencoder(
-        observation_size=observation_size, config=config.autoencoder_config
+        observation_size=observation_size["image"], config=config.autoencoder_config
     )
 
     actor_encoder = copy.deepcopy(autoencoder.encoder)
     critic_encoder = copy.deepcopy(autoencoder.encoder)
 
-    actor = Actor(actor_encoder, action_num, hidden_size=config.hidden_size)
-    critic = Critic(critic_encoder, action_num, hidden_size=config.hidden_size)
+    vector_observation_size = (
+        observation_size["vector"] if config.vector_observation else 0
+    )
+
+    actor = Actor(
+        vector_observation_size,
+        actor_encoder,
+        action_num,
+        hidden_size=config.hidden_size,
+    )
+    critic = Critic(
+        vector_observation_size,
+        critic_encoder,
+        action_num,
+        hidden_size=config.hidden_size,
+    )
 
     device = hlp.get_device()
     agent = TD3AE(
@@ -251,15 +275,21 @@ def create_NaSATD3(observation_size, action_num, config: AlgorithmConfig):
 
     ae_factory = AEFactory()
     autoencoder = ae_factory.create_autoencoder(
-        observation_size=observation_size, config=config.autoencoder_config
+        observation_size=observation_size["image"], config=config.autoencoder_config
+    )
+
+    vector_observation_size = (
+        observation_size["vector"] if config.vector_observation else 0
     )
 
     actor = Actor(
+        vector_observation_size,
         action_num,
         autoencoder,
         hidden_size=config.hidden_size,
     )
     critic = Critic(
+        vector_observation_size,
         action_num,
         autoencoder,
         hidden_size=config.hidden_size,
diff --git a/tests/test_algorithms.py b/tests/test_algorithms.py
index dfa159b7..705f0c62 100644
--- a/tests/test_algorithms.py
+++ b/tests/test_algorithms.py
@@ -17,18 +17,27 @@ def _policy_buffer(
     image_state,
     add_log_prob=False,
 ):
-    state = (
-        np.random.randint(255, size=observation_size, dtype=np.uint8)
-        if image_state
-        else list(range(observation_size))
-    )
+    if image_state:
+        state_vector = list(range(observation_size["vector"]))
+        state_image = np.random.randint(
+            255, size=observation_size["image"], dtype=np.uint8
+        )
+        state = {"image": state_image, "vector": state_vector}
+    else:
+        state = list(range(observation_size))
+
     action = list(range(action_num))
     reward = 10
-    next_state = (
-        np.random.randint(255, size=observation_size, dtype=np.uint8)
-        if image_state
-        else list(range(observation_size))
-    )
+
+    if image_state:
+        next_state_vector = list(range(observation_size["vector"]))
+        next_state_image = np.random.randint(
+            255, size=observation_size["image"], dtype=np.uint8
+        )
+        next_state = {"image": next_state_image, "vector": next_state_vector}
+    else:
+        next_state = list(range(observation_size))
+
     done = False
 
     for _ in range(capacity):
@@ -41,18 +50,28 @@
 
 
 def _value_buffer(memory_buffer, capacity, observation_size, action_num, image_state):
-    state = (
-        np.random.randint(255, size=observation_size, dtype=np.uint8)
-        if image_state
-        else list(range(observation_size))
-    )
+
+    if image_state:
+        state_vector = list(range(observation_size["vector"]))
+        state_image = np.random.randint(
+            255, size=observation_size["image"], dtype=np.uint8
+        )
+        state = {"image": state_image, "vector": state_vector}
+    else:
+        state = list(range(observation_size))
+
     action = randrange(action_num)
     reward = 10
-    next_state = (
-        np.random.randint(255, size=observation_size, dtype=np.uint8)
-        if image_state
-        else list(range(observation_size))
-    )
+
+    if image_state:
+        next_state_vector = list(range(observation_size["vector"]))
+        next_state_image = np.random.randint(
+            255, size=observation_size["image"], dtype=np.uint8
+        )
+        next_state = {"image": next_state_image, "vector": next_state_vector}
+    else:
+        next_state = list(range(observation_size))
+
     done = False
 
     for _ in range(capacity):
@@ -82,12 +101,13 @@ def test_algorithms():
 
     action_num = 2
 
     for algorithm, alg_config in algorithm_configurations.items():
+        alg_config = alg_config()
 
         memory_buffer = memory_factory.create_memory(alg_config)
 
         observation_size = (
-            observation_size_image
+            {"image": observation_size_image, "vector": observation_size_vector}
             if alg_config.image_observation
             else observation_size_vector
         )
@@ -121,3 +141,17 @@ def test_algorithms():
         assert isinstance(
             info, dict
         ), f"{algorithm} did not return a dictionary of training info"
+
+        intrinsic_on = (
+            bool(alg_config.intrinsic_on)
+            if hasattr(alg_config, "intrinsic_on")
+            else False
+        )
+
+        if intrinsic_on:
+            experiences = memory_buffer.sample_uniform(1)
+            states, actions, _, next_states, _, _ = experiences
+
+            intrinsic_reward = agent.get_intrinsic_reward(
+                states[0], actions[0], next_states[0]
+            )
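
Usage sketch (illustrative only, not part of the patch): the changes above move every image-based algorithm onto a dict observation with an "image" entry (uint8, 0-255) and a "vector" entry, batched and normalised by the new helpers in util/helpers.py. The observation shapes and the choice of device below are assumptions made for the example, not values taken from the repository.

import numpy as np

import cares_reinforcement_learning.util.helpers as hlp

# Hypothetical observation: three stacked 3x32x32 RGB frames plus a 5-dim proprioceptive vector.
state = {
    "image": np.random.randint(255, size=(9, 32, 32), dtype=np.uint8),
    "vector": np.zeros(5, dtype=np.float32),
}

# Returns {"image": 1xCxHxW float tensor scaled to [0, 1], "vector": 1xN float tensor},
# i.e. the format select_action_from_policy now expects for the agents in this patch.
state_tensor = hlp.image_state_dict_to_tensor(state, "cpu")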