Skip to content

Commit

Permalink
autoencoder type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
beardyFace committed Oct 10, 2024
1 parent d4c93b3 commit 054be40
Show file tree
Hide file tree
Showing 33 changed files with 59 additions and 61 deletions.
2 changes: 1 addition & 1 deletion .mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@

[mypy]
exclude = build
disable_error_code = import-untyped
# disable_error_code = import-untyped
2 changes: 1 addition & 1 deletion cares_reinforcement_learning/algorithm/policy/NaSATD3.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def __init__(
actor_network: nn.Module,
critic_network: nn.Module,
config: NaSATD3Config,
device: str,
device: torch.device,
):
self.type = "policy"
self.device = device
Expand Down
2 changes: 1 addition & 1 deletion cares_reinforcement_learning/algorithm/policy/TQC.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def __init__(
actor_network: torch.nn.Module,
critic_network: torch.nn.Module,
config: TQCConfig,
device: str,
device: torch.device,
):
self.type = "policy"

Expand Down
40 changes: 16 additions & 24 deletions cares_reinforcement_learning/encoders/autoencoder_factory.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import logging

import cares_reinforcement_learning.encoders.configurations as acf
from cares_reinforcement_learning.encoders import losses
from cares_reinforcement_learning.encoders.autoencoder import Autoencoder
from cares_reinforcement_learning.encoders.configurations import AEConfig
from cares_reinforcement_learning.encoders.burgess_autoencoder import BurgessAutoencoder
from cares_reinforcement_learning.encoders.vanilla_autoencoder import VanillaAutoencoder

# Disable these as this is a deliberate use of dynamic imports
# pylint: disable=import-outside-toplevel
Expand All @@ -11,11 +10,8 @@

def create_vanilla_autoencoder(
observation_size: tuple[int],
config: AEConfig,
) -> Autoencoder:
from cares_reinforcement_learning.encoders.vanilla_autoencoder import (
VanillaAutoencoder,
)
config: acf.VanillaAEConfig,
) -> VanillaAutoencoder:

return VanillaAutoencoder(
observation_size=observation_size,
Expand All @@ -31,11 +27,8 @@ def create_vanilla_autoencoder(

def create_burgess_autoencoder(
observation_size: tuple[int],
config: AEConfig,
) -> Autoencoder:
from cares_reinforcement_learning.encoders.burgess_autoencoder import (
BurgessAutoencoder,
)
config: acf.BurgessConfig,
) -> BurgessAutoencoder:

loss_function = losses.get_burgess_loss_function(config)

Expand All @@ -55,16 +48,15 @@ class AEFactory:
def create_autoencoder(
self,
observation_size: tuple[int],
config: AEConfig,
) -> Autoencoder:
config: acf.VanillaAEConfig | acf.BurgessConfig,
) -> BurgessAutoencoder | VanillaAutoencoder:

autoencoder = None
if config.type == "vanilla":
autoencoder = create_vanilla_autoencoder(observation_size, config)
elif config.type == "burgess":
autoencoder = create_burgess_autoencoder(observation_size, config)
if isinstance(config, acf.VanillaAEConfig):
return create_vanilla_autoencoder(observation_size, config)

if autoencoder is None:
logging.warning(f"Unkown autoencoder {autoencoder}.")
if isinstance(config, acf.BurgessConfig):
return create_burgess_autoencoder(observation_size, config)

return autoencoder
raise ValueError(
f"Invalid autoencoder configuration: {type(config)=} {config=}"
)
6 changes: 4 additions & 2 deletions cares_reinforcement_learning/encoders/vanilla_autoencoder.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Any

import torch
from torch import nn

Expand Down Expand Up @@ -51,8 +53,8 @@ def __init__(
num_filters: int = 32,
kernel_size: int = 3,
latent_lambda: float = 1e-6,
encoder_optimiser_params: dict[str, any] = None,
decoder_optimiser_params: dict[str, any] = None,
encoder_optimiser_params: dict[str, Any] | None = None,
decoder_optimiser_params: dict[str, Any] | None = None,
):
if encoder_optimiser_params is None:
encoder_optimiser_params = {"lr": 1e-4}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ def __init__(
self,
observation_size: int,
action_num: int,
hidden_size: list[int] = None,
hidden_size: list[int] | None = None,
):
super().__init__()
if hidden_size is None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def __init__(
ensemble_size: int,
observation_size: int,
action_num: int,
hidden_size: list[int] = None,
hidden_size: list[int] | None = None,
):
super().__init__()
if hidden_size is None:
Expand Down
2 changes: 1 addition & 1 deletion cares_reinforcement_learning/networks/DDPG/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ def __init__(
self,
observation_size: int,
num_actions: int,
hidden_size: list[int] = None,
hidden_size: list[int] | None = None,
):
super().__init__()
if hidden_size is None:
Expand Down
2 changes: 1 addition & 1 deletion cares_reinforcement_learning/networks/DDPG/critic.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ def __init__(
self,
observation_size: int,
num_actions: int,
hidden_size: list[int] = None,
hidden_size: list[int] | None = None,
):
super().__init__()
if hidden_size is None:
Expand Down
2 changes: 1 addition & 1 deletion cares_reinforcement_learning/networks/DQN/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ def __init__(
self,
observation_size: int,
num_actions: int,
hidden_size: list[int] = None,
hidden_size: list[int] | None = None,
):
super().__init__()
if hidden_size is None:
Expand Down
2 changes: 1 addition & 1 deletion cares_reinforcement_learning/networks/DoubleDQN/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ def __init__(
self,
observation_size: int,
num_actions: int,
hidden_size: list[int] = None,
hidden_size: list[int] | None = None,
):
super().__init__()
if hidden_size is None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ def __init__(
self,
observation_space_size: int,
action_num: int,
hidden_size: list[int] = None,
hidden_size: list[int] | None = None,
):
super().__init__()
if hidden_size is None:
Expand Down
2 changes: 1 addition & 1 deletion cares_reinforcement_learning/networks/MAPERTD3/critic.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ def __init__(
self,
observation_size: int,
num_actions: int,
hidden_size: list[int] = None,
hidden_size: list[int] | None = None,
):
super().__init__()
if hidden_size is None:
Expand Down
2 changes: 1 addition & 1 deletion cares_reinforcement_learning/networks/NaSATD3/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def __init__(
vector_observation_size: int,
num_actions: int,
autoencoder: Autoencoder,
hidden_size: list[int] = None,
hidden_size: list[int] | None = None,
):
super().__init__()
if hidden_size is None:
Expand Down
2 changes: 1 addition & 1 deletion cares_reinforcement_learning/networks/NaSATD3/critic.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def __init__(
vector_observation_size: int,
num_actions: int,
autoencoder: Autoencoder,
hidden_size: list[int] = None,
hidden_size: list[int] | None = None,
):
super().__init__()
if hidden_size is None:
Expand Down
2 changes: 1 addition & 1 deletion cares_reinforcement_learning/networks/PPO/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ def __init__(
self,
observation_size: int,
num_actions: int,
hidden_size: list[int] = None,
hidden_size: list[int] | None = None,
):
super().__init__()
if hidden_size is None:
Expand Down
2 changes: 1 addition & 1 deletion cares_reinforcement_learning/networks/PPO/critic.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ class Critic(nn.Module):
def __init__(
self,
observation_size: int,
hidden_size: list[int] = None,
hidden_size: list[int] | None = None,
):
super().__init__()
if hidden_size is None:
Expand Down
2 changes: 1 addition & 1 deletion cares_reinforcement_learning/networks/RDSAC/critic.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ def __init__(
self,
observation_size: int,
num_actions: int,
hidden_size: list[int] = None,
hidden_size: list[int] | None = None,
):
super().__init__()
if hidden_size is None:
Expand Down
2 changes: 1 addition & 1 deletion cares_reinforcement_learning/networks/RDTD3/critic.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ def __init__(
self,
observation_size: int,
num_actions: int,
hidden_size: list[int] = None,
hidden_size: list[int] | None = None,
):
super().__init__()
if hidden_size is None:
Expand Down
2 changes: 1 addition & 1 deletion cares_reinforcement_learning/networks/REDQ/critic.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ def __init__(
self,
observation_size: int,
num_actions: int,
hidden_size: list[int] = None,
hidden_size: list[int] | None = None,
):
super().__init__()
if hidden_size is None:
Expand Down
4 changes: 2 additions & 2 deletions cares_reinforcement_learning/networks/SAC/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ def __init__(
self,
observation_size: int,
num_actions: int,
hidden_size: list[int] = None,
log_std_bounds: list[int] = None,
hidden_size: list[int] | None = None,
log_std_bounds: list[float] | None = None,
):
super().__init__()
if hidden_size is None:
Expand Down
2 changes: 1 addition & 1 deletion cares_reinforcement_learning/networks/SAC/critic.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ def __init__(
self,
observation_size: int,
num_actions: int,
hidden_size: list[int] = None,
hidden_size: list[int] | None = None,
):
super().__init__()
if hidden_size is None:
Expand Down
4 changes: 2 additions & 2 deletions cares_reinforcement_learning/networks/SACAE/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ def __init__(
vector_observation_size: int,
encoder: Encoder,
num_actions: int,
hidden_size: list[int] = None,
log_std_bounds: list[int] = None,
hidden_size: list[int] | None = None,
log_std_bounds: list[float] | None = None,
):
if hidden_size is None:
hidden_size = [1024, 1024]
Expand Down
2 changes: 1 addition & 1 deletion cares_reinforcement_learning/networks/SACAE/critic.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def __init__(
vector_observation_size: int,
encoder: Encoder,
num_actions: int,
hidden_size: list[int] = None,
hidden_size: list[int] | None = None,
):
if hidden_size is None:
hidden_size = [1024, 1024]
Expand Down
2 changes: 1 addition & 1 deletion cares_reinforcement_learning/networks/TD3/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ def __init__(
self,
observation_size: int,
num_actions: int,
hidden_size: list[int] = None,
hidden_size: list[int] | None = None,
):
super().__init__()
if hidden_size is None:
Expand Down
2 changes: 1 addition & 1 deletion cares_reinforcement_learning/networks/TD3/critic.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ def __init__(
self,
observation_size: int,
num_actions: int,
hidden_size: list[int] = None,
hidden_size: list[int] | None = None,
):
super().__init__()
if hidden_size is None:
Expand Down
2 changes: 1 addition & 1 deletion cares_reinforcement_learning/networks/TD3AE/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def __init__(
vector_observation_size: int,
encoder: Encoder,
num_actions: int,
hidden_size: list[int] = None,
hidden_size: list[int] | None = None,
):
if hidden_size is None:
hidden_size = [1024, 1024]
Expand Down
2 changes: 1 addition & 1 deletion cares_reinforcement_learning/networks/TD3AE/critic.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def __init__(
vector_observation_size: int,
encoder: Encoder,
num_actions: int,
hidden_size: list[int] = None,
hidden_size: list[int] | None = None,
):
if hidden_size is None:
hidden_size = [1024, 1024]
Expand Down
2 changes: 1 addition & 1 deletion cares_reinforcement_learning/networks/TQC/critic.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def __init__(
num_actions: int,
num_quantiles: int,
num_critics: int,
hidden_size: list[int] = None,
hidden_size: list[int] | None = None,
):
super().__init__()
if hidden_size is None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def __init__(
num_actions: int,
num_models: int,
lr: float,
device: str,
device: torch.device,
hidden_size: int = 128,
):
self.num_models = num_models
Expand Down
3 changes: 2 additions & 1 deletion cares_reinforcement_learning/util/configurations.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
AEConfig,
VAEConfig,
VanillaAEConfig,
BurgessConfig,
)

# pylint: disable-next=unused-import
Expand Down Expand Up @@ -270,7 +271,7 @@ class NaSATD3Config(AlgorithmConfig):

vector_observation: int = 0

autoencoder_config: AEConfig = VanillaAEConfig(
autoencoder_config: VanillaAEConfig | BurgessConfig = VanillaAEConfig(
latent_dim=200,
num_layers=4,
num_filters=32,
Expand Down
1 change: 1 addition & 0 deletions cares_reinforcement_learning/util/network_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,6 +588,7 @@ def create_LA3PSAC(observation_size, action_num, config: acf.LA3PSACConfig):
return agent


# TODO: should create_network be annotated to return a base "Algorithm" class?
class NetworkFactory:
def create_network(
self,
Expand Down
10 changes: 6 additions & 4 deletions tests/test_ae.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,12 @@ def test_ae():

config = config(latent_dim=100)

autoencoder = factory.create_autoencoder(
observation_size=observation_size, config=config
)
assert autoencoder is not None, f"{ae} was not created successfully"
try:
autoencoder = factory.create_autoencoder(
observation_size=observation_size, config=config
)
except Exception as e:
pytest.fail(f"Exception making autoencoder: {ae} {e}")

autoencoder = autoencoder.to(device)

Expand Down

0 comments on commit 054be40

Please sign in to comment.