SACAE and TD3AE updated to reflect default + base + actor/critic layout - lays groundwork for future generalisation of AE for all methods
beardyFace committed Nov 25, 2024
1 parent 388ec91 commit d2679fc
Showing 12 changed files with 295 additions and 116 deletions.
4 changes: 2 additions & 2 deletions cares_reinforcement_learning/algorithm/policy/SACAE.py
@@ -236,10 +236,10 @@ def train_policy(self, memory: MemoryBuffer, batch_size: int) -> dict[str, Any]:
         if self.learn_counter % self.target_update_freq == 0:
             # Update the target networks - Soft Update
             hlp.soft_update_params(
-                self.critic_net.Q1, self.target_critic_net.Q1, self.tau
+                self.critic_net.critic.Q1, self.target_critic_net.critic.Q1, self.tau
             )
             hlp.soft_update_params(
-                self.critic_net.Q2, self.target_critic_net.Q2, self.tau
+                self.critic_net.critic.Q2, self.target_critic_net.critic.Q2, self.tau
             )
             hlp.soft_update_params(
                 self.critic_net.encoder,
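Note on the change above: the SACAE critic is now a composite module holding an encoder plus an inner SAC critic, so the twin Q-heads sit at `critic_net.critic.Q1`/`Q2` rather than directly on `critic_net`, and the target-network update has to follow them one level down. For readers unfamiliar with `hlp.soft_update_params`, the sketch below shows the conventional Polyak update such a helper performs; only the `(net, target_net, tau)` call signature is taken from this diff, the body is an assumption.

```python
# Hedged sketch of a Polyak/soft target update with the same (net, target_net, tau)
# signature used above. The real implementation lives in
# cares_reinforcement_learning.util.helpers and may differ in detail.
import torch
from torch import nn


def soft_update_params(net: nn.Module, target_net: nn.Module, tau: float) -> None:
    # theta_target <- tau * theta + (1 - tau) * theta_target
    with torch.no_grad():
        for param, target_param in zip(net.parameters(), target_net.parameters()):
            target_param.data.mul_(1.0 - tau)
            target_param.data.add_(tau * param.data)
```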
8 changes: 6 additions & 2 deletions cares_reinforcement_learning/algorithm/policy/TD3AE.py
@@ -199,10 +199,14 @@ def train_policy(self, memory: MemoryBuffer, batch_size: int) -> dict[str, Any]:
 
             # Update target network params
             hlp.soft_update_params(
-                self.critic_net.Q1, self.target_critic_net.Q1, self.tau
+                self.critic_net.critic.Q1,
+                self.target_critic_net.critic.Q1,
+                self.tau,
             )
             hlp.soft_update_params(
-                self.critic_net.Q2, self.target_critic_net.Q2, self.tau
+                self.critic_net.critic.Q2,
+                self.target_critic_net.critic.Q2,
+                self.tau,
             )
 
             hlp.soft_update_params(
4 changes: 3 additions & 1 deletion cares_reinforcement_learning/networks/SAC/actor.py
@@ -60,9 +60,11 @@ def __init__(
         self,
         observation_size: int,
         num_actions: int,
+        hidden_sizes: list[int] | None = None,
     ):
         log_std_bounds = [-20.0, 2.0]
-        hidden_sizes = [256, 256]
+        if hidden_sizes is None:
+            hidden_sizes = [256, 256]
 
         act_net = nn.Sequential(
             nn.Linear(observation_size, hidden_sizes[0]),
10 changes: 8 additions & 2 deletions cares_reinforcement_learning/networks/SAC/critic.py
@@ -29,9 +29,15 @@ def forward(
 
 # Default network should have this architecture with hidden_sizes = [256, 256]:
 class DefaultCritic(BaseCritic):
-    def __init__(self, observation_size: int, num_actions: int):
+    def __init__(
+        self,
+        observation_size: int,
+        num_actions: int,
+        hidden_sizes: list[int] | None = None,
+    ):
         input_size = observation_size + num_actions
-        hidden_sizes = [256, 256]
+        if hidden_sizes is None:
+            hidden_sizes = [256, 256]
 
         # Q1 architecture
         # pylint: disable-next=invalid-name
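With this change (and the matching one in the SAC actor above) the reference `[256, 256]` layout is only a fallback: the `None` sentinel lets callers such as the new SACAE default networks request other widths, e.g. `[1024, 1024]`. A short usage sketch, assuming only the constructor signatures shown in this diff:

```python
# Usage sketch for the new optional hidden_sizes argument. Import path and class
# names are as used elsewhere in this commit; the sizes are illustrative only.
from cares_reinforcement_learning.networks.SAC import DefaultActor, DefaultCritic

actor = DefaultActor(observation_size=10, num_actions=2)  # falls back to [256, 256]
critic = DefaultCritic(observation_size=10, num_actions=2, hidden_sizes=[1024, 1024])
```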
80 changes: 51 additions & 29 deletions cares_reinforcement_learning/networks/SACAE/actor.py
@@ -1,4 +1,5 @@
 import torch
+from torch import nn
 
 import cares_reinforcement_learning.util.helpers as hlp
 from cares_reinforcement_learning.encoders.vanilla_autoencoder import Encoder
@@ -7,19 +8,23 @@
 from cares_reinforcement_learning.util.configurations import SACAEConfig
 
 
-# class BaseActor(DefaultSACActor):
-#     def __init__(self, observation_size: int, num_actions: int):
-#         pass
-
-
-class BaseActor:
-    def __init__(self, encoder: Encoder, actor: SACActor | DefaultSACActor):
+class BaseActor(nn.Module):
+    def __init__(
+        self,
+        num_actions: int,
+        encoder: Encoder,
+        actor: SACActor | DefaultSACActor,
+        add_vector_observation: bool = False,
+    ):
         super().__init__()
 
+        self.num_actions = num_actions
         self.encoder = encoder
         self.actor = actor
 
-        self.add_vector_observation = False
+        self.add_vector_observation = add_vector_observation
 
         self.apply(hlp.weight_init)
 
     def forward(  # type: ignore
         self, state: dict[str, torch.Tensor], detach_encoder: bool = False
@@ -34,32 +39,49 @@ def forward(  # type: ignore
         return self.actor(actor_input)
 
 
-class Actor(SACActor):
-    def __init__(
-        self,
-        vector_observation_size: int,
-        encoder: Encoder,
-        num_actions: int,
-        config: SACAEConfig,
-    ):
+class DefaultActor(BaseActor):
+    def __init__(self, observation_size: dict, num_actions: int):
+
+        encoder = Encoder(
+            observation_size["image"],
+            latent_dim=50,
+            num_layers=4,
+            num_filters=32,
+            kernel_size=3,
+        )
+
+        actor = DefaultSACActor(
+            encoder.latent_dim, num_actions, hidden_sizes=[1024, 1024]
+        )
+
         super().__init__(
-            encoder.latent_dim + vector_observation_size, num_actions, config
+            num_actions,
+            encoder,
+            actor,
         )
 
-        self.encoder = encoder
-
-        self.vector_observation_size = vector_observation_size
+
+class Actor(BaseActor):
+    def __init__(self, observation_size: dict, num_actions: int, config: SACAEConfig):
 
-        self.apply(hlp.weight_init)
+        ae_config = config.autoencoder_config
+        encoder = Encoder(
+            observation_size["image"],
+            latent_dim=ae_config.latent_dim,
+            num_layers=ae_config.num_layers,
+            num_filters=ae_config.num_filters,
+            kernel_size=ae_config.kernel_size,
+        )
 
-    def forward(  # type: ignore
-        self, state: dict[str, torch.Tensor], detach_encoder: bool = False
-    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        # Detach at the CNN layer to prevent backpropagation through the encoder
-        state_latent = self.encoder(state["image"], detach_cnn=detach_encoder)
+        actor_observation_size = encoder.latent_dim
+        if config.vector_observation:
+            actor_observation_size += observation_size["vector"]
 
-        actor_input = state_latent
-        if self.vector_observation_size > 0:
-            actor_input = torch.cat([state["vector"], actor_input], dim=1)
+        actor = SACActor(actor_observation_size, num_actions, config)
 
-        return super().forward(actor_input)
+        super().__init__(
+            num_actions,
+            encoder=encoder,
+            actor=actor,
+            add_vector_observation=bool(config.vector_observation),
+        )
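The restructuring above is the "default + base" layout named in the commit message: `BaseActor` now just composes an image `Encoder` with a plain SAC actor head (optionally concatenating a vector observation onto the latent), `DefaultActor` pins the reference hyperparameters (latent dim 50, 4 conv layers, 32 filters, `[1024, 1024]` MLP), and `Actor` builds the same pieces from a `SACAEConfig`. A self-contained sketch of the composition pattern, with toy stand-in modules rather than the library's real classes:

```python
# Illustrative composition only: the class name, encoder, and head below are toy
# stand-ins, not the library's API. It mirrors the encoder -> (optional vector
# concat) -> actor-head data flow of the new BaseActor.
import torch
from torch import nn


class ComposedActor(nn.Module):
    def __init__(self, encoder: nn.Module, actor: nn.Module, add_vector_observation: bool = False):
        super().__init__()
        self.encoder = encoder
        self.actor = actor
        self.add_vector_observation = add_vector_observation

    def forward(self, state: dict[str, torch.Tensor]) -> torch.Tensor:
        latent = self.encoder(state["image"])
        actor_input = latent
        if self.add_vector_observation:
            # The vector part of the observation is prepended to the image latent.
            actor_input = torch.cat([state["vector"], actor_input], dim=1)
        return self.actor(actor_input)


# Toy usage: a 50-dim latent plus a 4-dim vector observation feeding a 2-action head.
encoder = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 50))
head = nn.Linear(50 + 4, 2)
policy = ComposedActor(encoder, head, add_vector_observation=True)
actions = policy({"image": torch.zeros(1, 3, 8, 8), "vector": torch.zeros(1, 4)})
```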
67 changes: 55 additions & 12 deletions cares_reinforcement_learning/networks/SACAE/critic.py
@@ -1,30 +1,30 @@
 import torch
+from torch import nn
 
 import cares_reinforcement_learning.util.helpers as hlp
 from cares_reinforcement_learning.encoders.vanilla_autoencoder import Encoder
 from cares_reinforcement_learning.networks.SAC import Critic as SACCritic
+from cares_reinforcement_learning.networks.SAC import DefaultCritic as DefaultSACCritic
 from cares_reinforcement_learning.util.configurations import SACAEConfig
 
 
-class Critic(SACCritic):
+class BaseCritic(nn.Module):
     def __init__(
         self,
-        vector_observation_size: int,
         encoder: Encoder,
-        num_actions: int,
-        config: SACAEConfig,
+        critic: SACCritic | DefaultSACCritic,
+        add_vector_observation: bool = False,
     ):
-        super().__init__(
-            encoder.latent_dim + vector_observation_size, num_actions, config
-        )
-
-        self.vector_observation_size = vector_observation_size
+        super().__init__()
 
         self.encoder = encoder
+        self.critic = critic
+
+        self.add_vector_observation = add_vector_observation
 
         self.apply(hlp.weight_init)
 
-    def forward(  # type: ignore
+    def forward(
         self,
         state: dict[str, torch.Tensor],
         action: torch.Tensor,
@@ -34,7 +34,50 @@ def forward(  # type: ignore
         state_latent = self.encoder(state["image"], detach_cnn=detach_encoder)
 
         critic_input = state_latent
-        if self.vector_observation_size > 0:
+        if self.add_vector_observation:
             critic_input = torch.cat([state["vector"], critic_input], dim=1)
 
-        return super().forward(critic_input, action)
+        return self.critic(critic_input, action)
+
+
+class DefaultCritic(BaseCritic):
+    def __init__(self, observation_size: dict, num_actions: int):
+
+        encoder = Encoder(
+            observation_size["image"],
+            latent_dim=50,
+            num_layers=4,
+            num_filters=32,
+            kernel_size=3,
+        )
+
+        critic = DefaultSACCritic(
+            encoder.latent_dim, num_actions, hidden_sizes=[1024, 1024]
+        )
+
+        super().__init__(encoder, critic)
+
+
+class Critic(BaseCritic):
+    def __init__(self, observation_size: dict, num_actions: int, config: SACAEConfig):
+
+        ae_config = config.autoencoder_config
+        encoder = Encoder(
+            observation_size["image"],
+            latent_dim=ae_config.latent_dim,
+            num_layers=ae_config.num_layers,
+            num_filters=ae_config.num_filters,
+            kernel_size=ae_config.kernel_size,
+        )
+
+        critic_observation_size = encoder.latent_dim
+        if config.vector_observation:
+            critic_observation_size += observation_size["vector"]
+
+        critic = SACCritic(critic_observation_size, num_actions, config)
+
+        super().__init__(
+            encoder=encoder,
+            critic=critic,
+            add_vector_observation=bool(config.vector_observation),
+        )
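One detail worth flagging: `BaseCritic.forward` still exposes `detach_encoder`, passed through as `detach_cnn` to the shared encoder. In the usual SAC-AE training recipe only the critic and autoencoder losses are allowed to shape the convolutional encoder, so the actor update detaches the latent. A minimal stand-alone illustration of that gradient-stopping pattern (plain PyTorch with toy modules, not the library's code):

```python
# Demonstrates the effect of detaching the shared encoder during an actor-style
# update: the head still receives gradients, the encoder does not.
import torch
from torch import nn

encoder = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 50))
actor_head = nn.Linear(50, 2)

image = torch.zeros(4, 3, 8, 8)

latent = encoder(image).detach()           # cut the graph at the encoder output
loss = actor_head(latent).pow(2).mean()
loss.backward()

print(encoder[1].weight.grad)              # None: no actor gradients reach the encoder
print(actor_head.weight.grad is not None)  # True: the actor head still learns
```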
10 changes: 8 additions & 2 deletions cares_reinforcement_learning/networks/TD3/actor.py
@@ -18,8 +18,14 @@ def forward(self, state: torch.Tensor) -> torch.Tensor:
 
 
 class DefaultActor(BaseActor):
-    def __init__(self, observation_size: int, num_actions: int):
-        hidden_sizes = [256, 256]
+    def __init__(
+        self,
+        observation_size: int,
+        num_actions: int,
+        hidden_sizes: list[int] | None = None,
+    ):
+        if hidden_sizes is None:
+            hidden_sizes = [256, 256]
 
         act_net = nn.Sequential(
             nn.Linear(observation_size, hidden_sizes[0]),
22 changes: 14 additions & 8 deletions cares_reinforcement_learning/networks/TD3/critic.py
@@ -28,27 +28,33 @@ def forward(
 
 # This is the default base network for TD3 for reference and testing of default network configurations
 class DefaultCritic(BaseCritic):
-    def __init__(self, observation_size: int, num_actions: int):
-        hidden_size = [256, 256]
+    def __init__(
+        self,
+        observation_size: int,
+        num_actions: int,
+        hidden_sizes: list[int] | None = None,
+    ):
+        if hidden_sizes is None:
+            hidden_sizes = [256, 256]
 
         # Q1 architecture
         # pylint: disable-next=invalid-name
         Q1 = nn.Sequential(
-            nn.Linear(observation_size + num_actions, hidden_size[0]),
+            nn.Linear(observation_size + num_actions, hidden_sizes[0]),
             nn.ReLU(),
-            nn.Linear(hidden_size[0], hidden_size[1]),
+            nn.Linear(hidden_sizes[0], hidden_sizes[1]),
             nn.ReLU(),
-            nn.Linear(hidden_size[1], 1),
+            nn.Linear(hidden_sizes[1], 1),
         )
 
         # Q2 architecture
         # pylint: disable-next=invalid-name
         Q2 = nn.Sequential(
-            nn.Linear(observation_size + num_actions, hidden_size[0]),
+            nn.Linear(observation_size + num_actions, hidden_sizes[0]),
             nn.ReLU(),
-            nn.Linear(hidden_size[0], hidden_size[1]),
+            nn.Linear(hidden_sizes[0], hidden_sizes[1]),
             nn.ReLU(),
-            nn.Linear(hidden_size[1], 1),
+            nn.Linear(hidden_sizes[1], 1),
         )
 
         super().__init__(Q1=Q1, Q2=Q2)
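As with the SAC networks above, the TD3 `DefaultCritic` now accepts an optional `hidden_sizes` and keeps two independent Q-heads over the concatenated state-action input. A usage sketch, assuming the constructor shown in this diff and the usual twin-Q forward that returns both heads' values (the `forward` body itself is outside this hunk):

```python
# Usage sketch only: the import path matches this file; the forward call pattern
# and (q1, q2) return convention are assumptions based on standard TD3 critics.
import torch
from cares_reinforcement_learning.networks.TD3.critic import DefaultCritic

critic = DefaultCritic(observation_size=17, num_actions=6, hidden_sizes=[400, 300])
state = torch.zeros(8, 17)
action = torch.zeros(8, 6)
q1, q2 = critic(state, action)  # two independent Q-value estimates, shape (8, 1) each
```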