Commit

vae + vector (#202)
* Added a vector_observation configuration option to enable concatenation of vector data to the latent output of the autoencoder in image-based methods.

* Image-based methods now take a dict[str, np.ndarray] with the expected keys "image" and "vector". image_wrapper in gym_env has been updated to reflect this requirement (an illustrative sketch of such a wrapper appears after the diffs below).

* SACAE image + vector

* TD3AE vector + image

* NaSATD3 + vector input.

* Helper functions added to handle transformation of the dict into the required tensors (see the sketch below).
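
The new helper functions appear in the diffs below as hlp.image_state_dict_to_tensor (single observation) and hlp.image_states_dict_to_tensor (batch), but their bodies are not among the files shown. The following is a minimal sketch of what they are assumed to do, mirroring the manual tensor construction in NaSATD3.get_intrinsic_reward (add a batch dimension, scale images from [0, 255] to [0, 1], leave the vector unscaled); whether the replay buffer yields a list of dicts or a dict of arrays is not visible here, so the batch version assumes the former:

import numpy as np
import torch


def image_state_dict_to_tensor(
    state: dict[str, np.ndarray], device: torch.device
) -> dict[str, torch.Tensor]:
    # Single observation: add a batch dimension and scale pixels to [0, 1].
    image = torch.FloatTensor(state["image"]).unsqueeze(0).to(device) / 255
    vector = torch.FloatTensor(state["vector"]).unsqueeze(0).to(device)
    return {"image": image, "vector": vector}


def image_states_dict_to_tensor(
    states: list[dict[str, np.ndarray]], device: torch.device
) -> dict[str, torch.Tensor]:
    # Batch of observations: stack each key, then scale the images.
    images = np.asarray([s["image"] for s in states])
    vectors = np.asarray([s["vector"] for s in states])
    return {
        "image": torch.FloatTensor(images).to(device) / 255,
        "vector": torch.FloatTensor(vectors).to(device),
    }

If the helpers normalise internally like this, that would also explain why the explicit states_normalised = states / 255 lines disappear from the SACAE and TD3AE training code below.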
beardyFace authored Oct 9, 2024
1 parent dbbb82e commit 01a1ea7
Showing 14 changed files with 314 additions and 148 deletions.
73 changes: 44 additions & 29 deletions cares_reinforcement_learning/algorithm/policy/NaSATD3.py
@@ -84,14 +84,15 @@ def __init__(
]

def select_action_from_policy(
self, state: np.ndarray, evaluation: bool = False, noise_scale: float = 0.1
self,
state: dict[str, np.ndarray],
evaluation: bool = False,
noise_scale: float = 0.1,
) -> np.ndarray:
self.actor.eval()
self.autoencoder.eval()
with torch.no_grad():
state_tensor = torch.FloatTensor(state).to(self.device)
state_tensor = state_tensor.unsqueeze(0)
state_tensor = state_tensor / 255
state_tensor = hlp.image_state_dict_to_tensor(state, self.device)

action = self.actor(state_tensor)
action = action.cpu().data.numpy().flatten()
@@ -108,7 +109,7 @@ def select_action_from_policy(

def _update_critic(
self,
states: torch.Tensor,
states: dict[str, torch.Tensor],
actions: torch.Tensor,
rewards: torch.Tensor,
next_states: torch.Tensor,
@@ -145,7 +146,7 @@ def _update_autoencoder(self, states: torch.Tensor) -> float:
ae_loss = self.autoencoder.update_autoencoder(states)
return ae_loss.item()

def _update_actor(self, states: torch.Tensor) -> float:
def _update_actor(self, states: dict[str, torch.Tensor]) -> float:
actor_q_one, actor_q_two = self.critic(
states, self.actor(states, detach_encoder=True), detach_encoder=True
)
@@ -174,16 +175,19 @@ def _get_latent_state(
return latent_state

def _update_predictive_model(
self, states: np.ndarray, actions: np.ndarray, next_states: np.ndarray
self,
states: dict[str, torch.Tensor],
actions: np.ndarray,
next_states: dict[str, torch.Tensor],
) -> list[float]:

with torch.no_grad():
latent_state = self._get_latent_state(
states, detach_output=True, sample_latent=True
states["image"], detach_output=True, sample_latent=True
)

latent_next_state = self._get_latent_state(
next_states, detach_output=True, sample_latent=True
next_states["image"], detach_output=True, sample_latent=True
)

pred_losses = []
@@ -218,17 +222,14 @@ def train_policy(self, memory: MemoryBuffer, batch_size: int) -> dict[str, Any]:

batch_size = len(states)

# Convert into tensor
states = torch.FloatTensor(np.asarray(states)).to(self.device)
states = hlp.image_states_dict_to_tensor(states, self.device)

actions = torch.FloatTensor(np.asarray(actions)).to(self.device)
rewards = torch.FloatTensor(np.asarray(rewards)).to(self.device)
next_states = torch.FloatTensor(np.asarray(next_states)).to(self.device)
dones = torch.LongTensor(np.asarray(dones)).to(self.device)

# Normalise states and next_states
# This because the states are [0-255] and the predictions are [0-1]
states /= 255
next_states /= 255
next_states = hlp.image_states_dict_to_tensor(next_states, self.device)

dones = torch.LongTensor(np.asarray(dones)).to(self.device)

# Reshape to batch_size
rewards = rewards.unsqueeze(0).reshape(batch_size, 1)
@@ -245,7 +246,7 @@ def train_policy(self, memory: MemoryBuffer, batch_size: int) -> dict[str, Any]:
info["critic_loss_total"] = critic_loss_total

# Update Autoencoder
ae_loss = self._update_autoencoder(states)
ae_loss = self._update_autoencoder(states["image"])
info["ae_loss"] = ae_loss

if self.learn_counter % self.policy_update_freq == 0:
@@ -322,26 +323,40 @@ def _get_novelty_rate(self, state_tensor_img: torch.Tensor) -> float:
return novelty_rate

def get_intrinsic_reward(
self, state: np.ndarray, action: np.ndarray, next_state: np.ndarray
self,
state: dict[str, np.ndarray],
action: np.ndarray,
next_state: dict[str, np.ndarray],
) -> float:
with torch.no_grad():
# Normalise states and next_states
# This because the states are [0-255] and the predictions are [0-1]
state_tensor = torch.FloatTensor(state).to(self.device)
state_tensor = state_tensor.unsqueeze(0)
state_tensor = state_tensor / 255
vector_tensor = torch.FloatTensor(state["vector"])
vector_tensor = vector_tensor.unsqueeze(0).to(self.device)

image_tensor = torch.FloatTensor(state["image"])
image_tensor = image_tensor.unsqueeze(0).to(self.device)
image_tensor = image_tensor / 255

state_tensor = {"image": image_tensor, "vector": vector_tensor}

next_vector_tensor = torch.FloatTensor(next_state["vector"])
next_vector_tensor = next_vector_tensor.unsqueeze(0).to(self.device)

next_image_tensor = torch.FloatTensor(next_state["image"])
next_image_tensor = next_image_tensor.unsqueeze(0).to(self.device)
next_image_tensor = next_image_tensor / 255

next_state_tensor = torch.FloatTensor(next_state).to(self.device)
next_state_tensor = next_state_tensor.unsqueeze(0)
next_state_tensor = next_state_tensor / 255
next_state_tensor = {
"image": next_image_tensor,
"vector": next_vector_tensor,
}

action_tensor = torch.FloatTensor(action).to(self.device)
action_tensor = action_tensor.unsqueeze(0)

surprise_rate = self._get_surprise_rate(
state_tensor, action_tensor, next_state_tensor
state_tensor["image"], action_tensor, next_state_tensor["image"]
)
novelty_rate = self._get_novelty_rate(state_tensor)
novelty_rate = self._get_novelty_rate(state_tensor["image"])

# TODO make these parameters - i.e. Tony's work
a = 1.0
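For orientation, a hedged example of the observation format that the updated select_action_from_policy methods now expect; the image size, channel count, and vector length below are illustrative assumptions, not values taken from the repository:

import numpy as np

# Illustrative observation in the dict format introduced by this commit.
state = {
    "image": np.random.randint(0, 256, size=(3, 84, 84), dtype=np.uint8),  # raw pixels in [0, 255]
    "vector": np.zeros(6, dtype=np.float32),  # e.g. proprioceptive readings
}

# Any of the image-based agents would then be queried directly with the dict, e.g.:
# action = agent.select_action_from_policy(state, evaluation=True)
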
36 changes: 18 additions & 18 deletions cares_reinforcement_learning/algorithm/policy/SACAE.py
@@ -15,7 +15,6 @@
import torch.nn.functional as F

import cares_reinforcement_learning.util.helpers as hlp
from cares_reinforcement_learning.encoders.configurations import VanillaAEConfig
from cares_reinforcement_learning.encoders.losses import AELoss
from cares_reinforcement_learning.memory import MemoryBuffer
from cares_reinforcement_learning.util.configurations import SACAEConfig
@@ -97,14 +96,15 @@ def __init__(

# pylint: disable-next=unused-argument
def select_action_from_policy(
self, state: np.ndarray, evaluation: bool = False, noise_scale: float = 0
self,
state: dict[str, np.ndarray],
evaluation: bool = False,
noise_scale: float = 0,
) -> np.ndarray:
# note that when evaluating this algorithm we need to select mu as action
self.actor_net.eval()
with torch.no_grad():
state_tensor = torch.FloatTensor(state)
state_tensor = state_tensor.unsqueeze(0).to(self.device)
state_tensor = state_tensor / 255
state_tensor = hlp.image_state_dict_to_tensor(state, self.device)

if evaluation:
(_, _, action) = self.actor_net(state_tensor)
@@ -120,12 +120,13 @@ def alpha(self) -> torch.Tensor:

def _update_critic(
self,
states: torch.Tensor,
states: dict[str, torch.Tensor],
actions: torch.Tensor,
rewards: torch.Tensor,
next_states: torch.Tensor,
dones: torch.Tensor,
) -> tuple[float, float, float]:

with torch.no_grad():
next_actions, next_log_pi, _ = self.actor_net(next_states)

@@ -153,7 +154,9 @@ def _update_critic(

return critic_loss_one.item(), critic_loss_two.item(), critic_loss_total.item()

def _update_actor_alpha(self, states: torch.Tensor) -> tuple[float, float]:
def _update_actor_alpha(
self, states: dict[str, torch.Tensor]
) -> tuple[float, float]:
pi, log_pi, _ = self.actor_net(states, detach_encoder=True)
qf1_pi, qf2_pi = self.critic_net(states, pi, detach_encoder=True)

@@ -199,35 +202,32 @@ def train_policy(self, memory: MemoryBuffer, batch_size: int) -> dict[str, Any]:

batch_size = len(states)

# Convert into tensor
states = torch.FloatTensor(np.asarray(states)).to(self.device)
states = hlp.image_states_dict_to_tensor(states, self.device)

actions = torch.FloatTensor(np.asarray(actions)).to(self.device)
rewards = torch.FloatTensor(np.asarray(rewards)).to(self.device)
next_states = torch.FloatTensor(np.asarray(next_states)).to(self.device)

next_states = hlp.image_states_dict_to_tensor(next_states, self.device)

dones = torch.LongTensor(np.asarray(dones)).to(self.device)

# Reshape to batch_size x whatever
rewards = rewards.unsqueeze(0).reshape(batch_size, 1)
dones = dones.unsqueeze(0).reshape(batch_size, 1)

# Normalise states and next_states
# This because the states are [0-255] and the predictions are [0-1]
states_normalised = states / 255
next_states_normalised = next_states / 255

info = {}

# Update the Critic
critic_loss_one, critic_loss_two, critic_loss_total = self._update_critic(
states_normalised, actions, rewards, next_states_normalised, dones
states, actions, rewards, next_states, dones
)
info["critic_loss_one"] = critic_loss_one
info["critic_loss_two"] = critic_loss_two
info["critic_loss"] = critic_loss_total

# Update the Actor
if self.learn_counter % self.policy_update_freq == 0:
actor_loss, alpha_loss = self._update_actor_alpha(states_normalised)
actor_loss, alpha_loss = self._update_actor_alpha(states)
info["actor_loss"] = actor_loss
info["alpha_loss"] = alpha_loss
info["alpha"] = self.alpha.item()
@@ -247,7 +247,7 @@ def train_policy(self, memory: MemoryBuffer, batch_size: int) -> dict[str, Any]:
)

if self.learn_counter % self.decoder_update_freq == 0:
ae_loss = self._update_autoencoder(states_normalised)
ae_loss = self._update_autoencoder(states["image"])
info["ae_loss"] = ae_loss

return info
33 changes: 15 additions & 18 deletions cares_reinforcement_learning/algorithm/policy/TD3AE.py
@@ -13,7 +13,6 @@
import torch.nn.functional as F

import cares_reinforcement_learning.util.helpers as hlp
from cares_reinforcement_learning.encoders.configurations import VanillaAEConfig
from cares_reinforcement_learning.encoders.losses import AELoss
from cares_reinforcement_learning.memory import MemoryBuffer
from cares_reinforcement_learning.util.configurations import TD3AEConfig
@@ -78,13 +77,14 @@ def __init__(
)

def select_action_from_policy(
self, state: np.ndarray, evaluation: bool = False, noise_scale: float = 0.1
self,
state: dict[str, np.ndarray],
evaluation: bool = False,
noise_scale: float = 0.1,
) -> np.ndarray:
self.actor_net.eval()
with torch.no_grad():
state_tensor = torch.FloatTensor(state).to(self.device)
state_tensor = state_tensor.unsqueeze(0)
state_tensor = state_tensor / 255
state_tensor = hlp.image_state_dict_to_tensor(state, self.device)

action = self.actor_net(state_tensor)
action = action.cpu().data.numpy().flatten()
@@ -98,7 +98,7 @@ def select_action_from_policy(

def _update_critic(
self,
states: torch.Tensor,
states: dict[str, torch.Tensor],
actions: torch.Tensor,
rewards: torch.Tensor,
next_states: torch.Tensor,
@@ -132,7 +132,7 @@ def _update_critic(

return critic_loss_one.item(), critic_loss_two.item(), critic_loss_total.item()

def _update_actor(self, states: torch.Tensor) -> float:
def _update_actor(self, states: dict[str, torch.Tensor]) -> float:
actions = self.actor_net(states, detach_encoder=True)
actor_q_values, _ = self.critic_net(states, actions, detach_encoder=True)
actor_loss = -actor_q_values.mean()
@@ -169,34 +169,31 @@ def train_policy(self, memory: MemoryBuffer, batch_size: int) -> dict[str, Any]:

batch_size = len(states)

# Convert into tensor
states = torch.FloatTensor(np.asarray(states)).to(self.device)
states = hlp.image_states_dict_to_tensor(states, self.device)

actions = torch.FloatTensor(np.asarray(actions)).to(self.device)
rewards = torch.FloatTensor(np.asarray(rewards)).to(self.device)
next_states = torch.FloatTensor(np.asarray(next_states)).to(self.device)

next_states = hlp.image_states_dict_to_tensor(next_states, self.device)

dones = torch.LongTensor(np.asarray(dones)).to(self.device)

# Reshape to batch_size
rewards = rewards.unsqueeze(0).reshape(batch_size, 1)
dones = dones.unsqueeze(0).reshape(batch_size, 1)

# Normalise states and next_states
# This because the states are [0-255] and the predictions are [0-1]
states_normalised = states / 255
next_states_normalised = next_states / 255

info = {}

critic_loss_one, critic_loss_two, critic_loss_total = self._update_critic(
states_normalised, actions, rewards, next_states_normalised, dones
states, actions, rewards, next_states, dones
)
info["critic_loss_one"] = critic_loss_one
info["critic_loss_two"] = critic_loss_two
info["critic_loss"] = critic_loss_total

if self.learn_counter % self.policy_update_freq == 0:
# Update Actor
actor_loss = self._update_actor(states_normalised)
actor_loss = self._update_actor(states)
info["actor_loss"] = actor_loss

# Update target network params
@@ -222,7 +219,7 @@ def train_policy(self, memory: MemoryBuffer, batch_size: int) -> dict[str, Any]:
)

if self.learn_counter % self.decoder_update_freq == 0:
ae_loss = self._update_autoencoder(states_normalised)
ae_loss = self._update_autoencoder(states["image"])
info["ae_loss"] = ae_loss

return info
23 changes: 0 additions & 23 deletions cares_reinforcement_learning/encoders/configurations.py
@@ -50,29 +50,6 @@ class VanillaAEConfig(AEConfig):
latent_lambda: float = 1e-6


# sqVAE = parser.add_argument_group('SQ-VAE specific parameters')
# sqVAE.add_argument('--dim_z', type=int, default=16)
# sqVAE.add_argument('--size_dict', type=int, default=512)
# sqVAE.add_argument('--param_var_q', type=str, default=ParamVarQ.GAUSSIAN_1.value,
# choices=[pvq.value for pvq in ParamVarQ])
# sqVAE.add_argument('--num_rb', type=int, default=6)
# sqVAE.add_argument('--flg_arelbo', type=bool, default=True)
# sqVAE.add_argument('--log_param_q_init', type=float, default=0.0)
# sqVAE.add_argument('--temperature_init', type=float, default=1.0)

# class SQVAEConfig(AEConfig):
# """
# Configuration class for the sqvae autoencoder.

# Attributes:

# """

# type: str = "sqvae"
# flg_arelbo: bool = Field(description="Flag to use arelbo loss function")
# loss_latent: str = Field(description="")


class BurgessConfig(AEConfig):
"""
Configuration class for the Burgess autoencoder.
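image_wrapper in gym_env and the helper module are among the 14 changed files but are not rendered above. Purely as an illustration of the observation assembly the commit message describes, a wrapper that emits the {image, vector} dict might look like the sketch below; the class name, the gymnasium-style environment API calls, and the channel layout are assumptions, not the repository's implementation.

import numpy as np


class ImageVectorWrapper:
    """Hypothetical sketch: pair rendered pixels with the vector observation
    in the dict format that SACAE, TD3AE and NaSATD3 expect after this commit."""

    def __init__(self, env):
        # Assumes a gymnasium-style env created with render_mode="rgb_array".
        self.env = env

    def _observation(self, vector_obs) -> dict[str, np.ndarray]:
        frame = self.env.render()          # H x W x C uint8 pixels in [0, 255]
        image = np.moveaxis(frame, -1, 0)  # channels-first for the image encoder
        return {"image": image, "vector": np.asarray(vector_obs, dtype=np.float32)}

    def reset(self):
        vector_obs, info = self.env.reset()
        return self._observation(vector_obs), info

    def step(self, action):
        vector_obs, reward, terminated, truncated, info = self.env.step(action)
        return self._observation(vector_obs), reward, terminated, truncated, info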
