diff --git a/.mypy.ini b/.mypy.ini index 493ef9d8..8ce00315 100644 --- a/.mypy.ini +++ b/.mypy.ini @@ -2,4 +2,4 @@ [mypy] exclude = build -disable_error_code = import-untyped \ No newline at end of file +# disable_error_code = import-untyped \ No newline at end of file diff --git a/cares_reinforcement_learning/algorithm/policy/NaSATD3.py b/cares_reinforcement_learning/algorithm/policy/NaSATD3.py index b439aff0..e19b1d90 100644 --- a/cares_reinforcement_learning/algorithm/policy/NaSATD3.py +++ b/cares_reinforcement_learning/algorithm/policy/NaSATD3.py @@ -25,7 +25,7 @@ def __init__( actor_network: nn.Module, critic_network: nn.Module, config: NaSATD3Config, - device: str, + device: torch.device, ): self.type = "policy" self.device = device diff --git a/cares_reinforcement_learning/algorithm/policy/TQC.py b/cares_reinforcement_learning/algorithm/policy/TQC.py index 3458b561..efbfbe6c 100644 --- a/cares_reinforcement_learning/algorithm/policy/TQC.py +++ b/cares_reinforcement_learning/algorithm/policy/TQC.py @@ -24,7 +24,7 @@ def __init__( actor_network: torch.nn.Module, critic_network: torch.nn.Module, config: TQCConfig, - device: str, + device: torch.device, ): self.type = "policy" diff --git a/cares_reinforcement_learning/encoders/autoencoder_factory.py b/cares_reinforcement_learning/encoders/autoencoder_factory.py index 564d1a36..0df8cb56 100644 --- a/cares_reinforcement_learning/encoders/autoencoder_factory.py +++ b/cares_reinforcement_learning/encoders/autoencoder_factory.py @@ -1,8 +1,7 @@ -import logging - +import cares_reinforcement_learning.encoders.configurations as acf from cares_reinforcement_learning.encoders import losses -from cares_reinforcement_learning.encoders.autoencoder import Autoencoder -from cares_reinforcement_learning.encoders.configurations import AEConfig +from cares_reinforcement_learning.encoders.burgess_autoencoder import BurgessAutoencoder +from cares_reinforcement_learning.encoders.vanilla_autoencoder import VanillaAutoencoder # Disable these as this is a deliberate use of dynamic imports # pylint: disable=import-outside-toplevel @@ -11,11 +10,8 @@ def create_vanilla_autoencoder( observation_size: tuple[int], - config: AEConfig, -) -> Autoencoder: - from cares_reinforcement_learning.encoders.vanilla_autoencoder import ( - VanillaAutoencoder, - ) + config: acf.VanillaAEConfig, +) -> VanillaAutoencoder: return VanillaAutoencoder( observation_size=observation_size, @@ -31,11 +27,8 @@ def create_vanilla_autoencoder( def create_burgess_autoencoder( observation_size: tuple[int], - config: AEConfig, -) -> Autoencoder: - from cares_reinforcement_learning.encoders.burgess_autoencoder import ( - BurgessAutoencoder, - ) + config: acf.BurgessConfig, +) -> BurgessAutoencoder: loss_function = losses.get_burgess_loss_function(config) @@ -55,16 +48,15 @@ class AEFactory: def create_autoencoder( self, observation_size: tuple[int], - config: AEConfig, - ) -> Autoencoder: + config: acf.VanillaAEConfig | acf.BurgessConfig, + ) -> BurgessAutoencoder | VanillaAutoencoder: - autoencoder = None - if config.type == "vanilla": - autoencoder = create_vanilla_autoencoder(observation_size, config) - elif config.type == "burgess": - autoencoder = create_burgess_autoencoder(observation_size, config) + if isinstance(config, acf.VanillaAEConfig): + return create_vanilla_autoencoder(observation_size, config) - if autoencoder is None: - logging.warning(f"Unkown autoencoder {autoencoder}.") + if isinstance(config, acf.BurgessConfig): + return create_burgess_autoencoder(observation_size, config) - return autoencoder + raise ValueError( + f"Invalid autoencoder configuration: {type(config)=} {config=}" + ) diff --git a/cares_reinforcement_learning/encoders/vanilla_autoencoder.py b/cares_reinforcement_learning/encoders/vanilla_autoencoder.py index 9dbd1129..325783e6 100644 --- a/cares_reinforcement_learning/encoders/vanilla_autoencoder.py +++ b/cares_reinforcement_learning/encoders/vanilla_autoencoder.py @@ -1,3 +1,5 @@ +from typing import Any + import torch from torch import nn @@ -51,8 +53,8 @@ def __init__( num_filters: int = 32, kernel_size: int = 3, latent_lambda: float = 1e-6, - encoder_optimiser_params: dict[str, any] = None, - decoder_optimiser_params: dict[str, any] = None, + encoder_optimiser_params: dict[str, Any] | None = None, + decoder_optimiser_params: dict[str, Any] | None = None, ): if encoder_optimiser_params is None: encoder_optimiser_params = {"lr": 1e-4} diff --git a/cares_reinforcement_learning/networks/CTD4/distributed_critic.py b/cares_reinforcement_learning/networks/CTD4/distributed_critic.py index 537fa81f..56a4a606 100644 --- a/cares_reinforcement_learning/networks/CTD4/distributed_critic.py +++ b/cares_reinforcement_learning/networks/CTD4/distributed_critic.py @@ -7,7 +7,7 @@ def __init__( self, observation_size: int, action_num: int, - hidden_size: list[int] = None, + hidden_size: list[int] | None = None, ): super().__init__() if hidden_size is None: diff --git a/cares_reinforcement_learning/networks/CTD4/ensemble_critic.py b/cares_reinforcement_learning/networks/CTD4/ensemble_critic.py index 3e2ee81a..1b81e669 100644 --- a/cares_reinforcement_learning/networks/CTD4/ensemble_critic.py +++ b/cares_reinforcement_learning/networks/CTD4/ensemble_critic.py @@ -10,7 +10,7 @@ def __init__( ensemble_size: int, observation_size: int, action_num: int, - hidden_size: list[int] = None, + hidden_size: list[int] | None = None, ): super().__init__() if hidden_size is None: diff --git a/cares_reinforcement_learning/networks/DDPG/actor.py b/cares_reinforcement_learning/networks/DDPG/actor.py index 5a4cd153..262f46e7 100644 --- a/cares_reinforcement_learning/networks/DDPG/actor.py +++ b/cares_reinforcement_learning/networks/DDPG/actor.py @@ -7,7 +7,7 @@ def __init__( self, observation_size: int, num_actions: int, - hidden_size: list[int] = None, + hidden_size: list[int] | None = None, ): super().__init__() if hidden_size is None: diff --git a/cares_reinforcement_learning/networks/DDPG/critic.py b/cares_reinforcement_learning/networks/DDPG/critic.py index dd92fdf9..591816cc 100644 --- a/cares_reinforcement_learning/networks/DDPG/critic.py +++ b/cares_reinforcement_learning/networks/DDPG/critic.py @@ -7,7 +7,7 @@ def __init__( self, observation_size: int, num_actions: int, - hidden_size: list[int] = None, + hidden_size: list[int] | None = None, ): super().__init__() if hidden_size is None: diff --git a/cares_reinforcement_learning/networks/DQN/network.py b/cares_reinforcement_learning/networks/DQN/network.py index 042f2f5f..ea582eb3 100644 --- a/cares_reinforcement_learning/networks/DQN/network.py +++ b/cares_reinforcement_learning/networks/DQN/network.py @@ -8,7 +8,7 @@ def __init__( self, observation_size: int, num_actions: int, - hidden_size: list[int] = None, + hidden_size: list[int] | None = None, ): super().__init__() if hidden_size is None: diff --git a/cares_reinforcement_learning/networks/DoubleDQN/network.py b/cares_reinforcement_learning/networks/DoubleDQN/network.py index 042f2f5f..ea582eb3 100644 --- a/cares_reinforcement_learning/networks/DoubleDQN/network.py +++ b/cares_reinforcement_learning/networks/DoubleDQN/network.py @@ -8,7 +8,7 @@ def __init__( self, observation_size: int, num_actions: int, - hidden_size: list[int] = None, + hidden_size: list[int] | None = None, ): super().__init__() if hidden_size is None: diff --git a/cares_reinforcement_learning/networks/DuelingDQN/network.py b/cares_reinforcement_learning/networks/DuelingDQN/network.py index 415a637e..9fc939d3 100644 --- a/cares_reinforcement_learning/networks/DuelingDQN/network.py +++ b/cares_reinforcement_learning/networks/DuelingDQN/network.py @@ -7,7 +7,7 @@ def __init__( self, observation_space_size: int, action_num: int, - hidden_size: list[int] = None, + hidden_size: list[int] | None = None, ): super().__init__() if hidden_size is None: diff --git a/cares_reinforcement_learning/networks/MAPERTD3/critic.py b/cares_reinforcement_learning/networks/MAPERTD3/critic.py index 5bc0f747..40e90299 100644 --- a/cares_reinforcement_learning/networks/MAPERTD3/critic.py +++ b/cares_reinforcement_learning/networks/MAPERTD3/critic.py @@ -7,7 +7,7 @@ def __init__( self, observation_size: int, num_actions: int, - hidden_size: list[int] = None, + hidden_size: list[int] | None = None, ): super().__init__() if hidden_size is None: diff --git a/cares_reinforcement_learning/networks/NaSATD3/actor.py b/cares_reinforcement_learning/networks/NaSATD3/actor.py index ac34a17f..9d8f5d40 100644 --- a/cares_reinforcement_learning/networks/NaSATD3/actor.py +++ b/cares_reinforcement_learning/networks/NaSATD3/actor.py @@ -14,7 +14,7 @@ def __init__( vector_observation_size: int, num_actions: int, autoencoder: Autoencoder, - hidden_size: list[int] = None, + hidden_size: list[int] | None = None, ): super().__init__() if hidden_size is None: diff --git a/cares_reinforcement_learning/networks/NaSATD3/critic.py b/cares_reinforcement_learning/networks/NaSATD3/critic.py index d4934999..7951e77c 100644 --- a/cares_reinforcement_learning/networks/NaSATD3/critic.py +++ b/cares_reinforcement_learning/networks/NaSATD3/critic.py @@ -14,7 +14,7 @@ def __init__( vector_observation_size: int, num_actions: int, autoencoder: Autoencoder, - hidden_size: list[int] = None, + hidden_size: list[int] | None = None, ): super().__init__() if hidden_size is None: diff --git a/cares_reinforcement_learning/networks/PPO/actor.py b/cares_reinforcement_learning/networks/PPO/actor.py index 5a4cd153..262f46e7 100644 --- a/cares_reinforcement_learning/networks/PPO/actor.py +++ b/cares_reinforcement_learning/networks/PPO/actor.py @@ -7,7 +7,7 @@ def __init__( self, observation_size: int, num_actions: int, - hidden_size: list[int] = None, + hidden_size: list[int] | None = None, ): super().__init__() if hidden_size is None: diff --git a/cares_reinforcement_learning/networks/PPO/critic.py b/cares_reinforcement_learning/networks/PPO/critic.py index 03ba8d24..0bf2c88d 100644 --- a/cares_reinforcement_learning/networks/PPO/critic.py +++ b/cares_reinforcement_learning/networks/PPO/critic.py @@ -6,7 +6,7 @@ class Critic(nn.Module): def __init__( self, observation_size: int, - hidden_size: list[int] = None, + hidden_size: list[int] | None = None, ): super().__init__() if hidden_size is None: diff --git a/cares_reinforcement_learning/networks/RDSAC/critic.py b/cares_reinforcement_learning/networks/RDSAC/critic.py index 666b8f13..63546dc2 100644 --- a/cares_reinforcement_learning/networks/RDSAC/critic.py +++ b/cares_reinforcement_learning/networks/RDSAC/critic.py @@ -7,7 +7,7 @@ def __init__( self, observation_size: int, num_actions: int, - hidden_size: list[int] = None, + hidden_size: list[int] | None = None, ): super().__init__() if hidden_size is None: diff --git a/cares_reinforcement_learning/networks/RDTD3/critic.py b/cares_reinforcement_learning/networks/RDTD3/critic.py index 002ca04f..e53fdeab 100644 --- a/cares_reinforcement_learning/networks/RDTD3/critic.py +++ b/cares_reinforcement_learning/networks/RDTD3/critic.py @@ -8,7 +8,7 @@ def __init__( self, observation_size: int, num_actions: int, - hidden_size: list[int] = None, + hidden_size: list[int] | None = None, ): super().__init__() if hidden_size is None: diff --git a/cares_reinforcement_learning/networks/REDQ/critic.py b/cares_reinforcement_learning/networks/REDQ/critic.py index 980aeb01..fe922795 100644 --- a/cares_reinforcement_learning/networks/REDQ/critic.py +++ b/cares_reinforcement_learning/networks/REDQ/critic.py @@ -7,7 +7,7 @@ def __init__( self, observation_size: int, num_actions: int, - hidden_size: list[int] = None, + hidden_size: list[int] | None = None, ): super().__init__() if hidden_size is None: diff --git a/cares_reinforcement_learning/networks/SAC/actor.py b/cares_reinforcement_learning/networks/SAC/actor.py index 6113c1b7..bdae0a02 100644 --- a/cares_reinforcement_learning/networks/SAC/actor.py +++ b/cares_reinforcement_learning/networks/SAC/actor.py @@ -12,8 +12,8 @@ def __init__( self, observation_size: int, num_actions: int, - hidden_size: list[int] = None, - log_std_bounds: list[int] = None, + hidden_size: list[int] | None = None, + log_std_bounds: list[float] | None = None, ): super().__init__() if hidden_size is None: diff --git a/cares_reinforcement_learning/networks/SAC/critic.py b/cares_reinforcement_learning/networks/SAC/critic.py index 2b47b53b..1810a70d 100644 --- a/cares_reinforcement_learning/networks/SAC/critic.py +++ b/cares_reinforcement_learning/networks/SAC/critic.py @@ -7,7 +7,7 @@ def __init__( self, observation_size: int, num_actions: int, - hidden_size: list[int] = None, + hidden_size: list[int] | None = None, ): super().__init__() if hidden_size is None: diff --git a/cares_reinforcement_learning/networks/SACAE/actor.py b/cares_reinforcement_learning/networks/SACAE/actor.py index 46928189..8a720973 100644 --- a/cares_reinforcement_learning/networks/SACAE/actor.py +++ b/cares_reinforcement_learning/networks/SACAE/actor.py @@ -11,8 +11,8 @@ def __init__( vector_observation_size: int, encoder: Encoder, num_actions: int, - hidden_size: list[int] = None, - log_std_bounds: list[int] = None, + hidden_size: list[int] | None = None, + log_std_bounds: list[float] | None = None, ): if hidden_size is None: hidden_size = [1024, 1024] diff --git a/cares_reinforcement_learning/networks/SACAE/critic.py b/cares_reinforcement_learning/networks/SACAE/critic.py index 5d9e59c4..ece681f8 100644 --- a/cares_reinforcement_learning/networks/SACAE/critic.py +++ b/cares_reinforcement_learning/networks/SACAE/critic.py @@ -11,7 +11,7 @@ def __init__( vector_observation_size: int, encoder: Encoder, num_actions: int, - hidden_size: list[int] = None, + hidden_size: list[int] | None = None, ): if hidden_size is None: hidden_size = [1024, 1024] diff --git a/cares_reinforcement_learning/networks/TD3/actor.py b/cares_reinforcement_learning/networks/TD3/actor.py index 35acb566..900931c7 100644 --- a/cares_reinforcement_learning/networks/TD3/actor.py +++ b/cares_reinforcement_learning/networks/TD3/actor.py @@ -7,7 +7,7 @@ def __init__( self, observation_size: int, num_actions: int, - hidden_size: list[int] = None, + hidden_size: list[int] | None = None, ): super().__init__() if hidden_size is None: diff --git a/cares_reinforcement_learning/networks/TD3/critic.py b/cares_reinforcement_learning/networks/TD3/critic.py index 2b47b53b..1810a70d 100644 --- a/cares_reinforcement_learning/networks/TD3/critic.py +++ b/cares_reinforcement_learning/networks/TD3/critic.py @@ -7,7 +7,7 @@ def __init__( self, observation_size: int, num_actions: int, - hidden_size: list[int] = None, + hidden_size: list[int] | None = None, ): super().__init__() if hidden_size is None: diff --git a/cares_reinforcement_learning/networks/TD3AE/actor.py b/cares_reinforcement_learning/networks/TD3AE/actor.py index 75ddf3af..610725e8 100644 --- a/cares_reinforcement_learning/networks/TD3AE/actor.py +++ b/cares_reinforcement_learning/networks/TD3AE/actor.py @@ -11,7 +11,7 @@ def __init__( vector_observation_size: int, encoder: Encoder, num_actions: int, - hidden_size: list[int] = None, + hidden_size: list[int] | None = None, ): if hidden_size is None: hidden_size = [1024, 1024] diff --git a/cares_reinforcement_learning/networks/TD3AE/critic.py b/cares_reinforcement_learning/networks/TD3AE/critic.py index 7401def9..4be9b619 100644 --- a/cares_reinforcement_learning/networks/TD3AE/critic.py +++ b/cares_reinforcement_learning/networks/TD3AE/critic.py @@ -11,7 +11,7 @@ def __init__( vector_observation_size: int, encoder: Encoder, num_actions: int, - hidden_size: list[int] = None, + hidden_size: list[int] | None = None, ): if hidden_size is None: hidden_size = [1024, 1024] diff --git a/cares_reinforcement_learning/networks/TQC/critic.py b/cares_reinforcement_learning/networks/TQC/critic.py index 4265a574..cc7fe094 100644 --- a/cares_reinforcement_learning/networks/TQC/critic.py +++ b/cares_reinforcement_learning/networks/TQC/critic.py @@ -11,7 +11,7 @@ def __init__( num_actions: int, num_quantiles: int, num_critics: int, - hidden_size: list[int] = None, + hidden_size: list[int] | None = None, ): super().__init__() if hidden_size is None: diff --git a/cares_reinforcement_learning/networks/world_models/ensemble_integrated.py b/cares_reinforcement_learning/networks/world_models/ensemble_integrated.py index b31a31ae..9b4c3f3e 100644 --- a/cares_reinforcement_learning/networks/world_models/ensemble_integrated.py +++ b/cares_reinforcement_learning/networks/world_models/ensemble_integrated.py @@ -145,7 +145,7 @@ def __init__( num_actions: int, num_models: int, lr: float, - device: str, + device: torch.device, hidden_size: int = 128, ): self.num_models = num_models diff --git a/cares_reinforcement_learning/util/configurations.py b/cares_reinforcement_learning/util/configurations.py index a8fbe6b6..088daee3 100644 --- a/cares_reinforcement_learning/util/configurations.py +++ b/cares_reinforcement_learning/util/configurations.py @@ -4,6 +4,7 @@ AEConfig, VAEConfig, VanillaAEConfig, + BurgessConfig, ) # pylint disbale-next=unused-import @@ -270,7 +271,7 @@ class NaSATD3Config(AlgorithmConfig): vector_observation: int = 0 - autoencoder_config: AEConfig = VanillaAEConfig( + autoencoder_config: VanillaAEConfig | BurgessConfig = VanillaAEConfig( latent_dim=200, num_layers=4, num_filters=32, diff --git a/cares_reinforcement_learning/util/network_factory.py b/cares_reinforcement_learning/util/network_factory.py index 2caa2a1f..b89f74f8 100644 --- a/cares_reinforcement_learning/util/network_factory.py +++ b/cares_reinforcement_learning/util/network_factory.py @@ -588,6 +588,7 @@ def create_LA3PSAC(observation_size, action_num, config: acf.LA3PSACConfig): return agent +# TODO return type base "Algorithm" class? class NetworkFactory: def create_network( self, diff --git a/tests/test_ae.py b/tests/test_ae.py index 64380692..f45e1a75 100644 --- a/tests/test_ae.py +++ b/tests/test_ae.py @@ -29,10 +29,12 @@ def test_ae(): config = config(latent_dim=100) - autoencoder = factory.create_autoencoder( - observation_size=observation_size, config=config - ) - assert autoencoder is not None, f"{ae} was not created successfully" + try: + autoencoder = factory.create_autoencoder( + observation_size=observation_size, config=config + ) + except Exception as e: + pytest.fail(f"Exception making autoencoder: {ae} {e}") autoencoder = autoencoder.to(device)