vae + vector #202

Merged (7 commits) on Oct 9, 2024
Changes from 5 commits
103 changes: 75 additions & 28 deletions cares_reinforcement_learning/algorithm/policy/NaSATD3.py
@@ -84,14 +84,22 @@ def __init__(
]

def select_action_from_policy(
self, state: np.ndarray, evaluation: bool = False, noise_scale: float = 0.1
self,
state: dict[str, np.ndarray],
evaluation: bool = False,
noise_scale: float = 0.1,
) -> np.ndarray:
self.actor.eval()
self.autoencoder.eval()
with torch.no_grad():
state_tensor = torch.FloatTensor(state).to(self.device)
state_tensor = state_tensor.unsqueeze(0)
state_tensor = state_tensor / 255
vector_tensor = torch.FloatTensor(state["vector"])
vector_tensor = vector_tensor.unsqueeze(0).to(self.device)

image_tensor = torch.FloatTensor(state["image"])
image_tensor = image_tensor.unsqueeze(0).to(self.device)
image_tensor = image_tensor / 255

state_tensor = {"image": image_tensor, "vector": vector_tensor}

action = self.actor(state_tensor)
action = action.cpu().data.numpy().flatten()
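As a rough illustration of the observation format this signature change expects, here is a minimal sketch of a dict state. The "image" and "vector" keys come from the diff; the shapes, dtypes, and the commented-out agent call are assumptions for illustration only.

```python
import numpy as np

# Hypothetical observation: a stacked pixel frame plus a small proprioceptive vector.
# The shapes (9x84x84 image, 3-dim vector) are illustrative, not taken from this PR.
state = {
    "image": np.random.randint(0, 256, size=(9, 84, 84), dtype=np.uint8),  # raw [0-255] pixels
    "vector": np.array([0.1, -0.3, 0.7], dtype=np.float32),                # e.g. joint positions
}

# select_action_from_policy normalises the image part internally (/ 255) and
# wraps both parts into a dict of batched tensors before calling the actor:
# action = agent.select_action_from_policy(state, evaluation=False)
```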
@@ -108,7 +116,7 @@ def select_action_from_policy

def _update_critic(
self,
states: torch.Tensor,
states: dict[str, torch.Tensor],
actions: torch.Tensor,
rewards: torch.Tensor,
next_states: torch.Tensor,
@@ -145,7 +153,7 @@ def _update_autoencoder(self, states: torch.Tensor) -> float:
ae_loss = self.autoencoder.update_autoencoder(states)
return ae_loss.item()

def _update_actor(self, states: torch.Tensor) -> float:
def _update_actor(self, states: dict[str, torch.Tensor]) -> float:
actor_q_one, actor_q_two = self.critic(
states, self.actor(states, detach_encoder=True), detach_encoder=True
)
@@ -174,16 +182,19 @@ def _get_latent_state(
return latent_state

def _update_predictive_model(
self, states: np.ndarray, actions: np.ndarray, next_states: np.ndarray
self,
states: dict[str, torch.Tensor],
actions: np.ndarray,
next_states: dict[str, torch.Tensor],
) -> list[float]:

with torch.no_grad():
latent_state = self._get_latent_state(
states, detach_output=True, sample_latent=True
states["image"], detach_output=True, sample_latent=True
)

latent_next_state = self._get_latent_state(
next_states, detach_output=True, sample_latent=True
next_states["image"], detach_output=True, sample_latent=True
)

pred_losses = []
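With this change the predictive models are trained purely on the image latent (via `_get_latent_state(states["image"], ...)`), with the action as the other input. Below is a hedged sketch of what one ensemble member's update step could look like; the network, sizes, and optimiser are stand-ins for illustration, not the repo's actual predictive-model classes.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

latent_dim, action_dim = 50, 4  # illustrative sizes, not taken from this PR

# Stand-in for one member of the predictive ensemble.
predictive_model = nn.Sequential(
    nn.Linear(latent_dim + action_dim, 128),
    nn.ReLU(),
    nn.Linear(128, latent_dim),
)
optimizer = torch.optim.Adam(predictive_model.parameters(), lr=1e-3)

def update_predictive_model(
    latent_state: torch.Tensor, action: torch.Tensor, latent_next_state: torch.Tensor
) -> float:
    """One regression step: predict the next image latent from (latent, action)."""
    prediction = predictive_model(torch.cat([latent_state, action], dim=-1))
    loss = F.mse_loss(prediction, latent_next_state)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

# Dummy batch of 2 latent states, actions, and next latent states.
loss = update_predictive_model(
    torch.rand(2, latent_dim), torch.rand(2, action_dim), torch.rand(2, latent_dim)
)
```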
@@ -216,19 +227,41 @@ def train_policy(self, memory: MemoryBuffer, batch_size: int) -> dict[str, Any]:
experiences = memory.sample_uniform(batch_size)
states, actions, rewards, next_states, dones, _ = experiences

batch_size = len(states)
states_images = [state["image"] for state in states]
states_vector = [state["vector"] for state in states]

next_states_images = [next_state["image"] for next_state in next_states]
next_states_vector = [next_state["vector"] for next_state in next_states]

batch_size = len(states_images)

# Convert into tensor
states = torch.FloatTensor(np.asarray(states)).to(self.device)
states_images = torch.FloatTensor(np.asarray(states_images)).to(self.device)
states_vector = torch.FloatTensor(np.asarray(states_vector)).to(self.device)

# Normalise states and next_states - image portion
# This is because the states are [0-255] and the predictions are [0-1]
states_images = states_images / 255

states = {"image": states_images, "vector": states_vector}

actions = torch.FloatTensor(np.asarray(actions)).to(self.device)
rewards = torch.FloatTensor(np.asarray(rewards)).to(self.device)
next_states = torch.FloatTensor(np.asarray(next_states)).to(self.device)
dones = torch.LongTensor(np.asarray(dones)).to(self.device)

# Normalise states and next_states
next_states_images = torch.FloatTensor(np.asarray(next_states_images)).to(
self.device
)
next_states_vector = torch.FloatTensor(np.asarray(next_states_vector)).to(
self.device
)

# Normalise states and next_states - image portion
# This is because the states are [0-255] and the predictions are [0-1]
states /= 255
next_states /= 255
next_states_images = next_states_images / 255

next_states = {"image": next_states_images, "vector": next_states_vector}

dones = torch.LongTensor(np.asarray(dones)).to(self.device)

# Reshape to batch_size
rewards = rewards.unsqueeze(0).reshape(batch_size, 1)
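The list-and-stack pattern used above for both states and next_states could be factored into a small helper. The sketch below just restates that pattern; the name stack_dict_observations and its placement are hypothetical and not part of this PR.

```python
import numpy as np
import torch

def stack_dict_observations(
    observations: list[dict[str, np.ndarray]], device: torch.device
) -> dict[str, torch.Tensor]:
    """Batch a list of {"image", "vector"} observations into a dict of tensors.

    Hypothetical helper illustrating the pattern used in train_policy.
    """
    images = torch.FloatTensor(
        np.asarray([obs["image"] for obs in observations])
    ).to(device)
    vectors = torch.FloatTensor(
        np.asarray([obs["vector"] for obs in observations])
    ).to(device)

    # Images are stored as [0-255]; the networks expect [0-1].
    return {"image": images / 255, "vector": vectors}

# Example usage with two dummy observations:
obs = [
    {"image": np.zeros((9, 84, 84), dtype=np.uint8), "vector": np.zeros(3, dtype=np.float32)}
    for _ in range(2)
]
batch = stack_dict_observations(obs, torch.device("cpu"))
print(batch["image"].shape, batch["vector"].shape)  # torch.Size([2, 9, 84, 84]) torch.Size([2, 3])
```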
@@ -245,7 +278,7 @@ def train_policy(self, memory: MemoryBuffer, batch_size: int) -> dict[str, Any]:
info["critic_loss_total"] = critic_loss_total

# Update Autoencoder
ae_loss = self._update_autoencoder(states)
ae_loss = self._update_autoencoder(states["image"])
info["ae_loss"] = ae_loss

if self.learn_counter % self.policy_update_freq == 0:
@@ -322,26 +355,40 @@ def _get_novelty_rate(self, state_tensor_img: torch.Tensor) -> float:
return novelty_rate

def get_intrinsic_reward(
self, state: np.ndarray, action: np.ndarray, next_state: np.ndarray
self,
state: dict[str, np.ndarray],
action: np.ndarray,
next_state: dict[str, np.ndarray],
) -> float:
with torch.no_grad():
# Normalise states and next_states
# This is because the states are [0-255] and the predictions are [0-1]
state_tensor = torch.FloatTensor(state).to(self.device)
state_tensor = state_tensor.unsqueeze(0)
state_tensor = state_tensor / 255
vector_tensor = torch.FloatTensor(state["vector"])
vector_tensor = vector_tensor.unsqueeze(0).to(self.device)

image_tensor = torch.FloatTensor(state["image"])
image_tensor = image_tensor.unsqueeze(0).to(self.device)
image_tensor = image_tensor / 255

state_tensor = {"image": image_tensor, "vector": vector_tensor}

next_vector_tensor = torch.FloatTensor(next_state["vector"])
next_vector_tensor = next_vector_tensor.unsqueeze(0).to(self.device)

next_image_tensor = torch.FloatTensor(next_state["image"])
next_image_tensor = next_image_tensor.unsqueeze(0).to(self.device)
next_image_tensor = next_image_tensor / 255

next_state_tensor = torch.FloatTensor(next_state).to(self.device)
next_state_tensor = next_state_tensor.unsqueeze(0)
next_state_tensor = next_state_tensor / 255
next_state_tensor = {
"image": next_image_tensor,
"vector": next_vector_tensor,
}

action_tensor = torch.FloatTensor(action).to(self.device)
action_tensor = action_tensor.unsqueeze(0)

surprise_rate = self._get_surprise_rate(
state_tensor, action_tensor, next_state_tensor
state_tensor["image"], action_tensor, next_state_tensor["image"]
)
novelty_rate = self._get_novelty_rate(state_tensor)
novelty_rate = self._get_novelty_rate(state_tensor["image"])

# TODO make these parameters - i.e. Tony's work
a = 1.0
69 changes: 51 additions & 18 deletions cares_reinforcement_learning/algorithm/policy/SACAE.py
@@ -15,7 +15,6 @@
import torch.nn.functional as F

import cares_reinforcement_learning.util.helpers as hlp
from cares_reinforcement_learning.encoders.configurations import VanillaAEConfig
from cares_reinforcement_learning.encoders.losses import AELoss
from cares_reinforcement_learning.memory import MemoryBuffer
from cares_reinforcement_learning.util.configurations import SACAEConfig
@@ -97,14 +96,23 @@ def __init__(

# pylint: disable-next=unused-argument
def select_action_from_policy(
self, state: np.ndarray, evaluation: bool = False, noise_scale: float = 0
self,
state: dict[str, np.ndarray],
evaluation: bool = False,
noise_scale: float = 0,
) -> np.ndarray:
# note that when evaluating this algorithm we need to select mu as action
self.actor_net.eval()
with torch.no_grad():
state_tensor = torch.FloatTensor(state)
state_tensor = state_tensor.unsqueeze(0).to(self.device)
state_tensor = state_tensor / 255

vector_tensor = torch.FloatTensor(state["vector"])
vector_tensor = vector_tensor.unsqueeze(0).to(self.device)

image_tensor = torch.FloatTensor(state["image"])
image_tensor = image_tensor.unsqueeze(0).to(self.device)
image_tensor = image_tensor / 255

state_tensor = {"image": image_tensor, "vector": vector_tensor}

if evaluation:
(_, _, action) = self.actor_net(state_tensor)
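Both the actor and critic now receive a dict state rather than a single tensor. The diff does not show the network changes, but a common way to consume such a state is to encode the image and concatenate the latent with the vector part; the module below is purely an illustrative stand-in for whatever the repo's networks actually do, with all sizes assumed.

```python
import torch
import torch.nn as nn

class DictStateEncoder(nn.Module):
    """Illustrative fusion module: image encoder output concatenated with the vector part.

    A sketch of the general pattern only, not the actual network definitions touched by this PR.
    """

    def __init__(self, image_channels: int = 9, vector_dim: int = 3, latent_dim: int = 50):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(image_channels, 32, kernel_size=3, stride=2),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(32, latent_dim),
        )

    def forward(self, state: dict[str, torch.Tensor]) -> torch.Tensor:
        latent = self.conv(state["image"])                    # (B, latent_dim)
        return torch.cat([latent, state["vector"]], dim=-1)   # (B, latent_dim + vector_dim)

# Example: batch of 2 normalised images plus 3-dim vectors.
encoder = DictStateEncoder()
fused = encoder({"image": torch.rand(2, 9, 84, 84), "vector": torch.rand(2, 3)})
print(fused.shape)  # torch.Size([2, 53])
```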
@@ -120,12 +128,13 @@ def alpha(self) -> torch.Tensor:

def _update_critic(
self,
states: torch.Tensor,
states: dict[str, torch.Tensor],
actions: torch.Tensor,
rewards: torch.Tensor,
next_states: torch.Tensor,
dones: torch.Tensor,
) -> tuple[float, float, float]:

with torch.no_grad():
next_actions, next_log_pi, _ = self.actor_net(next_states)

@@ -153,7 +162,9 @@ def _update_critic(

return critic_loss_one.item(), critic_loss_two.item(), critic_loss_total.item()

def _update_actor_alpha(self, states: torch.Tensor) -> tuple[float, float]:
def _update_actor_alpha(
self, states: dict[str, torch.Tensor]
) -> tuple[float, float]:
pi, log_pi, _ = self.actor_net(states, detach_encoder=True)
qf1_pi, qf2_pi = self.critic_net(states, pi, detach_encoder=True)

@@ -197,37 +208,59 @@ def train_policy(self, memory: MemoryBuffer, batch_size: int) -> dict[str, Any]:
experiences = memory.sample_uniform(batch_size)
states, actions, rewards, next_states, dones, _ = experiences

batch_size = len(states)
states_images = [state["image"] for state in states]
states_vector = [state["vector"] for state in states]

next_states_images = [next_state["image"] for next_state in next_states]
next_states_vector = [next_state["vector"] for next_state in next_states]

batch_size = len(states_images)

# Convert into tensor
states = torch.FloatTensor(np.asarray(states)).to(self.device)
states_images = torch.FloatTensor(np.asarray(states_images)).to(self.device)
states_vector = torch.FloatTensor(np.asarray(states_vector)).to(self.device)

# Normalise states and next_states - image portion
# This is because the states are [0-255] and the predictions are [0-1]
states_images = states_images / 255

states = {"image": states_images, "vector": states_vector}

actions = torch.FloatTensor(np.asarray(actions)).to(self.device)
rewards = torch.FloatTensor(np.asarray(rewards)).to(self.device)
next_states = torch.FloatTensor(np.asarray(next_states)).to(self.device)

next_states_images = torch.FloatTensor(np.asarray(next_states_images)).to(
self.device
)
next_states_vector = torch.FloatTensor(np.asarray(next_states_vector)).to(
self.device
)

# Normalise states and next_states - image portion
# This is because the states are [0-255] and the predictions are [0-1]
next_states_images = next_states_images / 255

next_states = {"image": next_states_images, "vector": next_states_vector}

dones = torch.LongTensor(np.asarray(dones)).to(self.device)

# Reshape to (batch_size, 1)
rewards = rewards.unsqueeze(0).reshape(batch_size, 1)
dones = dones.unsqueeze(0).reshape(batch_size, 1)

# Normalise states and next_states
# This because the states are [0-255] and the predictions are [0-1]
states_normalised = states / 255
next_states_normalised = next_states / 255

info = {}

# Update the Critic
critic_loss_one, critic_loss_two, critic_loss_total = self._update_critic(
states_normalised, actions, rewards, next_states_normalised, dones
states, actions, rewards, next_states, dones
)
info["critic_loss_one"] = critic_loss_one
info["critic_loss_two"] = critic_loss_two
info["critic_loss"] = critic_loss_total

# Update the Actor
if self.learn_counter % self.policy_update_freq == 0:
actor_loss, alpha_loss = self._update_actor_alpha(states_normalised)
actor_loss, alpha_loss = self._update_actor_alpha(states)
info["actor_loss"] = actor_loss
info["alpha_loss"] = alpha_loss
info["alpha"] = self.alpha.item()
@@ -247,7 +280,7 @@ def train_policy(self, memory: MemoryBuffer, batch_size: int) -> dict[str, Any]:
)

if self.learn_counter % self.decoder_update_freq == 0:
ae_loss = self._update_autoencoder(states_normalised)
ae_loss = self._update_autoencoder(states["image"])
info["ae_loss"] = ae_loss

return info
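In both algorithms the autoencoder update now receives only the image part of the state dict. Below is a toy sketch of what an image-only reconstruction step looks like; the flat encoder/decoder, sizes, and optimiser are stand-ins for the repo's actual autoencoder classes and are not taken from this PR.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy stand-ins for the repo's encoder/decoder; the real classes live in
# cares_reinforcement_learning.encoders and are not reproduced here.
encoder = nn.Sequential(nn.Flatten(), nn.Linear(9 * 84 * 84, 50))
decoder = nn.Sequential(nn.Linear(50, 9 * 84 * 84), nn.Unflatten(1, (9, 84, 84)))
optimizer = torch.optim.Adam(
    list(encoder.parameters()) + list(decoder.parameters()), lr=1e-3
)

def update_autoencoder(images: torch.Tensor) -> float:
    """One reconstruction step on the image portion of the batch only."""
    latent = encoder(images)
    reconstruction = decoder(latent)
    loss = F.mse_loss(reconstruction, images)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

# Example usage with a dummy, already-normalised [0-1] image batch.
ae_loss = update_autoencoder(torch.rand(2, 9, 84, 84))
```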