From d2679fcfc6028ceba9b624d8f4815b418384fbee Mon Sep 17 00:00:00 2001
From: beardyface
Date: Tue, 26 Nov 2024 09:19:39 +1300
Subject: [PATCH] SACAE and TD3AE updated to reflect default + base +
 actor/critic layout

- lays groundwork for future generalisation of AE for all methods
---
 .../algorithm/policy/SACAE.py |  4 +-
 .../algorithm/policy/TD3AE.py |  8 +-
 .../networks/SAC/actor.py     |  4 +-
 .../networks/SAC/critic.py    | 10 ++-
 .../networks/SACAE/actor.py   | 80 ++++++++++++-------
 .../networks/SACAE/critic.py  | 67 +++++++++++++---
 .../networks/TD3/actor.py     | 10 ++-
 .../networks/TD3/critic.py    | 22 +++--
 .../networks/TD3AE/actor.py   | 75 ++++++++++++++---
 .../networks/TD3AE/critic.py  | 68 +++++++++++++---
 .../util/configurations.py    |  9 ++-
 .../util/network_factory.py   | 54 ++++++-------
 12 files changed, 295 insertions(+), 116 deletions(-)

diff --git a/cares_reinforcement_learning/algorithm/policy/SACAE.py b/cares_reinforcement_learning/algorithm/policy/SACAE.py
index b4649a8..6a73138 100644
--- a/cares_reinforcement_learning/algorithm/policy/SACAE.py
+++ b/cares_reinforcement_learning/algorithm/policy/SACAE.py
@@ -236,10 +236,10 @@ def train_policy(self, memory: MemoryBuffer, batch_size: int) -> dict[str, Any]:
         if self.learn_counter % self.target_update_freq == 0:
             # Update the target networks - Soft Update
             hlp.soft_update_params(
-                self.critic_net.Q1, self.target_critic_net.Q1, self.tau
+                self.critic_net.critic.Q1, self.target_critic_net.critic.Q1, self.tau
             )
             hlp.soft_update_params(
-                self.critic_net.Q2, self.target_critic_net.Q2, self.tau
+                self.critic_net.critic.Q2, self.target_critic_net.critic.Q2, self.tau
             )
             hlp.soft_update_params(
                 self.critic_net.encoder,
diff --git a/cares_reinforcement_learning/algorithm/policy/TD3AE.py b/cares_reinforcement_learning/algorithm/policy/TD3AE.py
index 7ab2f51..12dae8b 100644
--- a/cares_reinforcement_learning/algorithm/policy/TD3AE.py
+++ b/cares_reinforcement_learning/algorithm/policy/TD3AE.py
@@ -199,10 +199,14 @@ def train_policy(self, memory: MemoryBuffer, batch_size: int) -> dict[str, Any]:
 
             # Update target network params
             hlp.soft_update_params(
-                self.critic_net.Q1, self.target_critic_net.Q1, self.tau
+                self.critic_net.critic.Q1,
+                self.target_critic_net.critic.Q1,
+                self.tau,
             )
             hlp.soft_update_params(
-                self.critic_net.Q2, self.target_critic_net.Q2, self.tau
+                self.critic_net.critic.Q2,
+                self.target_critic_net.critic.Q2,
+                self.tau,
             )
 
             hlp.soft_update_params(
diff --git a/cares_reinforcement_learning/networks/SAC/actor.py b/cares_reinforcement_learning/networks/SAC/actor.py
index e5a46bb..9dca6d8 100644
--- a/cares_reinforcement_learning/networks/SAC/actor.py
+++ b/cares_reinforcement_learning/networks/SAC/actor.py
@@ -60,9 +60,11 @@ def __init__(
         self,
         observation_size: int,
         num_actions: int,
+        hidden_sizes: list[int] | None = None,
     ):
         log_std_bounds = [-20.0, 2.0]
-        hidden_sizes = [256, 256]
+        if hidden_sizes is None:
+            hidden_sizes = [256, 256]
 
         act_net = nn.Sequential(
             nn.Linear(observation_size, hidden_sizes[0]),
diff --git a/cares_reinforcement_learning/networks/SAC/critic.py b/cares_reinforcement_learning/networks/SAC/critic.py
index f0992d8..db0dc9a 100644
--- a/cares_reinforcement_learning/networks/SAC/critic.py
+++ b/cares_reinforcement_learning/networks/SAC/critic.py
@@ -29,9 +29,15 @@ def forward(
 
 # Default network should have this architecture with hidden_sizes = [256, 256]:
 class DefaultCritic(BaseCritic):
-    def __init__(self, observation_size: int, num_actions: int):
+    def __init__(
+        self,
+        observation_size: int,
+        num_actions: int,
+        hidden_sizes: list[int] | None = None,
+    ):
         input_size = observation_size + num_actions
-        hidden_sizes = [256, 256]
+        if hidden_sizes is None:
+            hidden_sizes = [256, 256]
 
         # Q1 architecture
         # pylint: disable-next=invalid-name
diff --git a/cares_reinforcement_learning/networks/SACAE/actor.py b/cares_reinforcement_learning/networks/SACAE/actor.py
index 11abb43..3c04868 100644
--- a/cares_reinforcement_learning/networks/SACAE/actor.py
+++ b/cares_reinforcement_learning/networks/SACAE/actor.py
@@ -1,4 +1,5 @@
 import torch
+from torch import nn
 
 import cares_reinforcement_learning.util.helpers as hlp
 from cares_reinforcement_learning.encoders.vanilla_autoencoder import Encoder
@@ -7,19 +8,23 @@
 from cares_reinforcement_learning.util.configurations import SACAEConfig
 
 
-# class BaseActor(DefaultSACActor):
-#     def __init__(self, observation_size: int, num_actions: int):
-#         pass
-
-
-class BaseActor:
-    def __init__(self, encoder: Encoder, actor: SACActor | DefaultSACActor):
+class BaseActor(nn.Module):
+    def __init__(
+        self,
+        num_actions: int,
+        encoder: Encoder,
+        actor: SACActor | DefaultSACActor,
+        add_vector_observation: bool = False,
+    ):
         super().__init__()
 
+        self.num_actions = num_actions
         self.encoder = encoder
         self.actor = actor
 
-        self.add_vector_observation = False
+        self.add_vector_observation = add_vector_observation
+
+        self.apply(hlp.weight_init)
 
     def forward(  # type: ignore
         self, state: dict[str, torch.Tensor], detach_encoder: bool = False
@@ -34,32 +39,49 @@ def forward(  # type: ignore
         return self.actor(actor_input)
 
 
-class Actor(SACActor):
-    def __init__(
-        self,
-        vector_observation_size: int,
-        encoder: Encoder,
-        num_actions: int,
-        config: SACAEConfig,
-    ):
+class DefaultActor(BaseActor):
+    def __init__(self, observation_size: dict, num_actions: int):
+
+        encoder = Encoder(
+            observation_size["image"],
+            latent_dim=50,
+            num_layers=4,
+            num_filters=32,
+            kernel_size=3,
+        )
+
+        actor = DefaultSACActor(
+            encoder.latent_dim, num_actions, hidden_sizes=[1024, 1024]
+        )
+
         super().__init__(
-            encoder.latent_dim + vector_observation_size, num_actions, config
+            num_actions,
+            encoder,
+            actor,
         )
 
-        self.encoder = encoder
 
-        self.vector_observation_size = vector_observation_size
+class Actor(BaseActor):
+    def __init__(self, observation_size: dict, num_actions: int, config: SACAEConfig):
 
-        self.apply(hlp.weight_init)
+        ae_config = config.autoencoder_config
+        encoder = Encoder(
+            observation_size["image"],
+            latent_dim=ae_config.latent_dim,
+            num_layers=ae_config.num_layers,
+            num_filters=ae_config.num_filters,
+            kernel_size=ae_config.kernel_size,
+        )
 
-    def forward(  # type: ignore
-        self, state: dict[str, torch.Tensor], detach_encoder: bool = False
-    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        # Detach at the CNN layer to prevent backpropagation through the encoder
-        state_latent = self.encoder(state["image"], detach_cnn=detach_encoder)
+        actor_observation_size = encoder.latent_dim
+        if config.vector_observation:
+            actor_observation_size += observation_size["vector"]
 
-        actor_input = state_latent
-        if self.vector_observation_size > 0:
-            actor_input = torch.cat([state["vector"], actor_input], dim=1)
+        actor = SACActor(actor_observation_size, num_actions, config)
 
-        return super().forward(actor_input)
+        super().__init__(
+            num_actions,
+            encoder=encoder,
+            actor=actor,
+            add_vector_observation=bool(config.vector_observation),
+        )
diff --git a/cares_reinforcement_learning/networks/SACAE/critic.py b/cares_reinforcement_learning/networks/SACAE/critic.py
index d5f4529..20610b5 100644
--- a/cares_reinforcement_learning/networks/SACAE/critic.py
+++ b/cares_reinforcement_learning/networks/SACAE/critic.py
@@ -1,30 +1,30 @@
 import torch
+from torch import nn
 
 import cares_reinforcement_learning.util.helpers as hlp
 from cares_reinforcement_learning.encoders.vanilla_autoencoder import Encoder
 from cares_reinforcement_learning.networks.SAC import Critic as SACCritic
+from cares_reinforcement_learning.networks.SAC import DefaultCritic as DefaultSACCritic
 from cares_reinforcement_learning.util.configurations import SACAEConfig
 
 
-class Critic(SACCritic):
+class BaseCritic(nn.Module):
     def __init__(
         self,
-        vector_observation_size: int,
         encoder: Encoder,
-        num_actions: int,
-        config: SACAEConfig,
+        critic: SACCritic | DefaultSACCritic,
+        add_vector_observation: bool = False,
     ):
-        super().__init__(
-            encoder.latent_dim + vector_observation_size, num_actions, config
-        )
-
-        self.vector_observation_size = vector_observation_size
+        super().__init__()
 
         self.encoder = encoder
+        self.critic = critic
+
+        self.add_vector_observation = add_vector_observation
 
         self.apply(hlp.weight_init)
 
-    def forward(  # type: ignore
+    def forward(
         self,
         state: dict[str, torch.Tensor],
         action: torch.Tensor,
@@ -34,7 +34,50 @@ def forward(  # type: ignore
         state_latent = self.encoder(state["image"], detach_cnn=detach_encoder)
 
         critic_input = state_latent
-        if self.vector_observation_size > 0:
+        if self.add_vector_observation:
             critic_input = torch.cat([state["vector"], critic_input], dim=1)
 
-        return super().forward(critic_input, action)
+        return self.critic(critic_input, action)
+
+
+class DefaultCritic(BaseCritic):
+    def __init__(self, observation_size: dict, num_actions: int):
+
+        encoder = Encoder(
+            observation_size["image"],
+            latent_dim=50,
+            num_layers=4,
+            num_filters=32,
+            kernel_size=3,
+        )
+
+        critic = DefaultSACCritic(
+            encoder.latent_dim, num_actions, hidden_sizes=[1024, 1024]
+        )
+
+        super().__init__(encoder, critic)
+
+
+class Critic(BaseCritic):
+    def __init__(self, observation_size: dict, num_actions: int, config: SACAEConfig):
+
+        ae_config = config.autoencoder_config
+        encoder = Encoder(
+            observation_size["image"],
+            latent_dim=ae_config.latent_dim,
+            num_layers=ae_config.num_layers,
+            num_filters=ae_config.num_filters,
+            kernel_size=ae_config.kernel_size,
+        )
+
+        critic_observation_size = encoder.latent_dim
+        if config.vector_observation:
+            critic_observation_size += observation_size["vector"]
+
+        critic = SACCritic(critic_observation_size, num_actions, config)
+
+        super().__init__(
+            encoder=encoder,
+            critic=critic,
+            add_vector_observation=bool(config.vector_observation),
+        )
diff --git a/cares_reinforcement_learning/networks/TD3/actor.py b/cares_reinforcement_learning/networks/TD3/actor.py
index a250501..48e9deb 100644
--- a/cares_reinforcement_learning/networks/TD3/actor.py
+++ b/cares_reinforcement_learning/networks/TD3/actor.py
@@ -18,8 +18,14 @@ def forward(self, state: torch.Tensor) -> torch.Tensor:
 
 
 class DefaultActor(BaseActor):
-    def __init__(self, observation_size: int, num_actions: int):
-        hidden_sizes = [256, 256]
+    def __init__(
+        self,
+        observation_size: int,
+        num_actions: int,
+        hidden_sizes: list[int] | None = None,
+    ):
+        if hidden_sizes is None:
+            hidden_sizes = [256, 256]
 
         act_net = nn.Sequential(
             nn.Linear(observation_size, hidden_sizes[0]),
diff --git a/cares_reinforcement_learning/networks/TD3/critic.py b/cares_reinforcement_learning/networks/TD3/critic.py
index 2295067..d6959f4 100644
--- a/cares_reinforcement_learning/networks/TD3/critic.py
+++ b/cares_reinforcement_learning/networks/TD3/critic.py
@@ -28,27 +28,33 @@ def forward(
 
 # This is the default base network for TD3 for reference and testing of default network configurations
 class DefaultCritic(BaseCritic):
-    def __init__(self, observation_size: int, num_actions: int):
-        hidden_size = [256, 256]
+    def __init__(
+        self,
+        observation_size: int,
+        num_actions: int,
+        hidden_sizes: list[int] | None = None,
+    ):
+        if hidden_sizes is None:
+            hidden_sizes = [256, 256]
 
         # Q1 architecture
         # pylint: disable-next=invalid-name
         Q1 = nn.Sequential(
-            nn.Linear(observation_size + num_actions, hidden_size[0]),
+            nn.Linear(observation_size + num_actions, hidden_sizes[0]),
             nn.ReLU(),
-            nn.Linear(hidden_size[0], hidden_size[1]),
+            nn.Linear(hidden_sizes[0], hidden_sizes[1]),
             nn.ReLU(),
-            nn.Linear(hidden_size[1], 1),
+            nn.Linear(hidden_sizes[1], 1),
         )
 
         # Q2 architecture
         # pylint: disable-next=invalid-name
         Q2 = nn.Sequential(
-            nn.Linear(observation_size + num_actions, hidden_size[0]),
+            nn.Linear(observation_size + num_actions, hidden_sizes[0]),
             nn.ReLU(),
-            nn.Linear(hidden_size[0], hidden_size[1]),
+            nn.Linear(hidden_sizes[0], hidden_sizes[1]),
             nn.ReLU(),
-            nn.Linear(hidden_size[1], 1),
+            nn.Linear(hidden_sizes[1], 1),
         )
 
         super().__init__(Q1=Q1, Q2=Q2)
diff --git a/cares_reinforcement_learning/networks/TD3AE/actor.py b/cares_reinforcement_learning/networks/TD3AE/actor.py
index e52dfd1..163f5f3 100644
--- a/cares_reinforcement_learning/networks/TD3AE/actor.py
+++ b/cares_reinforcement_learning/networks/TD3AE/actor.py
@@ -1,38 +1,87 @@
 import torch
+from torch import nn
 
 import cares_reinforcement_learning.util.helpers as hlp
 from cares_reinforcement_learning.encoders.vanilla_autoencoder import Encoder
 from cares_reinforcement_learning.networks.TD3 import Actor as TD3Actor
+from cares_reinforcement_learning.networks.TD3 import DefaultActor as DefaultTD3Actor
 from cares_reinforcement_learning.util.configurations import TD3AEConfig
 
 
-class Actor(TD3Actor):
+class BaseActor(nn.Module):
     def __init__(
         self,
-        vector_observation_size: int,
-        encoder: Encoder,
         num_actions: int,
-        config: TD3AEConfig,
+        encoder: Encoder,
+        actor: TD3Actor | DefaultTD3Actor,
+        add_vector_observation: bool = False,
     ):
+        super().__init__()
 
-        super().__init__(
-            encoder.latent_dim + vector_observation_size, num_actions, config
-        )
-
+        self.num_actions = num_actions
         self.encoder = encoder
+        self.actor = actor
 
-        self.apply(hlp.weight_init)
+        self.add_vector_observation = add_vector_observation
 
-        self.vector_observation_size = vector_observation_size
+        self.apply(hlp.weight_init)
 
     def forward(  # type: ignore
         self, state: dict[str, torch.Tensor], detach_encoder: bool = False
-    ) -> torch.Tensor:
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         # Detach at the CNN layer to prevent backpropagation through the encoder
         state_latent = self.encoder(state["image"], detach_cnn=detach_encoder)
 
         actor_input = state_latent
-        if self.vector_observation_size > 0:
+        if self.add_vector_observation:
             actor_input = torch.cat([state["vector"], actor_input], dim=1)
 
-        return super().forward(actor_input)
+        return self.actor(actor_input)
+
+
+class DefaultActor(BaseActor):
+    def __init__(self, observation_size: dict, num_actions: int):
+
+        encoder = Encoder(
+            observation_size["image"],
+            latent_dim=50,
+            num_layers=4,
+            num_filters=32,
+            kernel_size=3,
+        )
+
+        actor = DefaultTD3Actor(
+            encoder.latent_dim, num_actions, hidden_sizes=[1024, 1024]
+        )
+
+        super().__init__(
+            num_actions,
+            encoder,
+            actor,
+        )
+
+
+class Actor(BaseActor):
+    def __init__(self, observation_size: dict, num_actions: int, config: TD3AEConfig):
+
+        ae_config = config.autoencoder_config
+        encoder = Encoder(
+            observation_size["image"],
+            latent_dim=ae_config.latent_dim,
+            num_layers=ae_config.num_layers,
+            num_filters=ae_config.num_filters,
+            kernel_size=ae_config.kernel_size,
+        )
+
+        actor_observation_size = encoder.latent_dim
+        if config.vector_observation:
+            actor_observation_size += observation_size["vector"]
+
+        actor = TD3Actor(actor_observation_size, num_actions, config)
+
+        super().__init__(
+            num_actions,
+            encoder=encoder,
+            actor=actor,
+            add_vector_observation=bool(config.vector_observation),
+        )
diff --git a/cares_reinforcement_learning/networks/TD3AE/critic.py b/cares_reinforcement_learning/networks/TD3AE/critic.py
index f6b5570..dedb237 100644
--- a/cares_reinforcement_learning/networks/TD3AE/critic.py
+++ b/cares_reinforcement_learning/networks/TD3AE/critic.py
@@ -1,31 +1,30 @@
 import torch
+from torch import nn
 
 import cares_reinforcement_learning.util.helpers as hlp
 from cares_reinforcement_learning.encoders.vanilla_autoencoder import Encoder
+from cares_reinforcement_learning.networks.TD3 import DefaultCritic as DefaultTD3Critic
 from cares_reinforcement_learning.networks.TD3 import Critic as TD3Critic
 from cares_reinforcement_learning.util.configurations import TD3AEConfig
 
 
-class Critic(TD3Critic):
+class BaseCritic(nn.Module):
     def __init__(
         self,
-        vector_observation_size: int,
         encoder: Encoder,
-        num_actions: int,
-        config: TD3AEConfig,
+        critic: TD3Critic | DefaultTD3Critic,
+        add_vector_observation: bool = False,
     ):
-
-        super().__init__(
-            encoder.latent_dim + vector_observation_size, num_actions, config
-        )
-
-        self.vector_observation_size = vector_observation_size
+        super().__init__()
 
         self.encoder = encoder
+        self.critic = critic
+
+        self.add_vector_observation = add_vector_observation
 
         self.apply(hlp.weight_init)
 
-    def forward(  # type: ignore
+    def forward(
         self,
         state: dict[str, torch.Tensor],
         action: torch.Tensor,
@@ -35,7 +34,50 @@ def forward(  # type: ignore
         state_latent = self.encoder(state["image"], detach_cnn=detach_encoder)
 
         critic_input = state_latent
-        if self.vector_observation_size > 0:
+        if self.add_vector_observation:
             critic_input = torch.cat([state["vector"], critic_input], dim=1)
 
-        return super().forward(critic_input, action)
+        return self.critic(critic_input, action)
+
+
+class DefaultCritic(BaseCritic):
+    def __init__(self, observation_size: dict, num_actions: int):
+
+        encoder = Encoder(
+            observation_size["image"],
+            latent_dim=50,
+            num_layers=4,
+            num_filters=32,
+            kernel_size=3,
+        )
+
+        critic = DefaultTD3Critic(
+            encoder.latent_dim, num_actions, hidden_sizes=[1024, 1024]
+        )
+
+        super().__init__(encoder, critic)
+
+
+class Critic(BaseCritic):
+    def __init__(self, observation_size: dict, num_actions: int, config: TD3AEConfig):
+
+        ae_config = config.autoencoder_config
+        encoder = Encoder(
+            observation_size["image"],
+            latent_dim=ae_config.latent_dim,
+            num_layers=ae_config.num_layers,
+            num_filters=ae_config.num_filters,
+            kernel_size=ae_config.kernel_size,
+        )
+
+        critic_observation_size = encoder.latent_dim
+        if config.vector_observation:
+            critic_observation_size += observation_size["vector"]
+
+        critic = TD3Critic(critic_observation_size, num_actions, config)
+
+        super().__init__(
+            encoder=encoder,
+            critic=critic,
+            add_vector_observation=bool(config.vector_observation),
+        )
diff --git a/cares_reinforcement_learning/util/configurations.py b/cares_reinforcement_learning/util/configurations.py
index af20169..1fe6b39 100644
--- a/cares_reinforcement_learning/util/configurations.py
+++ b/cares_reinforcement_learning/util/configurations.py
@@ -86,9 +86,12 @@ class AlgorithmConfig(SubscriptableClass):
 
     @root_validator(pre=True)
     def convert_none_to_dict(cls, values):  # pylint: disable-next=no-self-argument
-        for field, value in values.items():
-            if cls.__annotations__.get(field) == dict and value is None:
-                values[field] = {}
+        if values.get("norm_layer_args") is None:
+            values["norm_layer_args"] = {}
+        if values.get("activation_function_args") is None:
+            values["activation_function_args"] = {}
+        if values.get("final_activation_args") is None:
+            values["final_activation_args"] = {}
 
         return values
 
diff --git a/cares_reinforcement_learning/util/network_factory.py b/cares_reinforcement_learning/util/network_factory.py
index 0f4d40d..18a1c05 100644
--- a/cares_reinforcement_learning/util/network_factory.py
+++ b/cares_reinforcement_learning/util/network_factory.py
@@ -8,8 +8,8 @@
 import logging
 import sys
 
-import cares_reinforcement_learning.util.helpers as hlp
 import cares_reinforcement_learning.util.configurations as acf
+import cares_reinforcement_learning.util.helpers as hlp
 
 # Disable these as this is a deliberate use of dynamic imports
 # pylint: disable=import-outside-toplevel
@@ -123,29 +123,27 @@ def create_SAC(observation_size, action_num, config: acf.SACConfig):
 
 def create_SACAE(observation_size, action_num, config: acf.SACAEConfig):
     from cares_reinforcement_learning.algorithm.policy import SACAE
-    from cares_reinforcement_learning.encoders.autoencoder_factory import AEFactory
+    from cares_reinforcement_learning.encoders.vanilla_autoencoder import Decoder
     from cares_reinforcement_learning.networks.SACAE import Actor, Critic
 
-    ae_factory = AEFactory()
-    autoencoder = ae_factory.create_autoencoder(
-        observation_size=observation_size["image"], config=config.autoencoder_config
-    )
-
-    actor_encoder = copy.deepcopy(autoencoder.encoder)
-    critic_encoder = copy.deepcopy(autoencoder.encoder)
+    actor = Actor(observation_size, action_num, config=config)
+    critic = Critic(observation_size, action_num, config=config)
 
-    vector_observation_size = (
-        observation_size["vector"] if config.vector_observation else 0
+    ae_config = config.autoencoder_config
+    decoder = Decoder(
+        observation_size["image"],
+        out_dim=actor.encoder.out_dim,
+        latent_dim=ae_config.latent_dim,
+        num_layers=ae_config.num_layers,
+        num_filters=ae_config.num_filters,
+        kernel_size=ae_config.kernel_size,
     )
 
-    actor = Actor(vector_observation_size, actor_encoder, action_num, config=config)
-    critic = Critic(vector_observation_size, critic_encoder, action_num, config=config)
-
     device = hlp.get_device()
     agent = SACAE(
         actor_network=actor,
         critic_network=critic,
-        decoder_network=autoencoder.decoder,
+        decoder_network=decoder,
         config=config,
         device=device,
     )
@@ -344,29 +342,27 @@ def create_TD3(observation_size, action_num, config: acf.TD3Config):
 
 def create_TD3AE(observation_size, action_num, config: acf.TD3AEConfig):
     from cares_reinforcement_learning.algorithm.policy import TD3AE
-    from cares_reinforcement_learning.encoders.autoencoder_factory import AEFactory
+    from cares_reinforcement_learning.encoders.vanilla_autoencoder import Decoder
    from cares_reinforcement_learning.networks.TD3AE import Actor, Critic
 
-    ae_factory = AEFactory()
-    autoencoder = ae_factory.create_autoencoder(
-        observation_size=observation_size["image"], config=config.autoencoder_config
-    )
-
-    actor_encoder = copy.deepcopy(autoencoder.encoder)
-    critic_encoder = copy.deepcopy(autoencoder.encoder)
+    actor = Actor(observation_size, action_num, config=config)
+    critic = Critic(observation_size, action_num, config=config)
 
-    vector_observation_size = (
-        observation_size["vector"] if config.vector_observation else 0
+    ae_config = config.autoencoder_config
+    decoder = Decoder(
+        observation_size["image"],
+        out_dim=actor.encoder.out_dim,
+        latent_dim=ae_config.latent_dim,
+        num_layers=ae_config.num_layers,
+        num_filters=ae_config.num_filters,
+        kernel_size=ae_config.kernel_size,
    )
 
-    actor = Actor(vector_observation_size, actor_encoder, action_num, config=config)
-    critic = Critic(vector_observation_size, critic_encoder, action_num, config=config)
-
     device = hlp.get_device()
     agent = TD3AE(
         actor_network=actor,
         critic_network=critic,
-        decoder_network=autoencoder.decoder,
+        decoder_network=decoder,
         config=config,
         device=device,
     )
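
Note (illustrative, not part of the patch): with the default + base + actor/critic layout above, the AE-based networks are built directly from the observation-size dictionary and the algorithm config, and the factory derives the decoder from the actor's encoder (out_dim=actor.encoder.out_dim). A minimal construction sketch, assuming SACAEConfig and its autoencoder_config are usable with their defaults; the image shape (9, 84, 84) and 5-dimensional vector observation are hypothetical placeholders:

    from cares_reinforcement_learning.networks.SACAE import Actor, Critic
    from cares_reinforcement_learning.util.configurations import SACAEConfig

    # Hypothetical observation layout; real shapes come from the environment wrapper.
    observation_size = {"image": (9, 84, 84), "vector": 5}
    num_actions = 4

    config = SACAEConfig()  # assumes the default autoencoder_config is acceptable as-is
    actor = Actor(observation_size, num_actions, config=config)
    critic = Critic(observation_size, num_actions, config=config)

    # Both wrappers expose the same layout introduced here: an encoder plus a plain
    # SAC actor/critic operating on the latent (optionally vector-augmented) input.
    print(actor.encoder.latent_dim, critic.add_vector_observation)

The same construction pattern applies to the TD3AE Actor/Critic, which is what lays the groundwork for generalising the autoencoder wrapping to other methods.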