Type hint for types in networks
beardyFace committed Nov 25, 2024
1 parent d2679fc commit 56be414
Showing 28 changed files with 81 additions and 54 deletions.
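Every file follows the same pattern: constructor arguments previously hinted as a generic torch.nn.Module (or torch.nn.ModuleList for the critic ensembles) are now hinted with the concrete Actor/Critic classes imported from that algorithm's own networks package. A minimal sketch of the idea — the network bodies below are hypothetical stand-ins; only the type-hint change itself mirrors this commit:

```python
import torch
from torch import nn


# Hypothetical stand-in for e.g. cares_reinforcement_learning.networks.TD3.Actor;
# the real class's layers and constructor differ.
class Actor(nn.Module):
    def __init__(self, observation_size: int, num_actions: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(observation_size, 256),
            nn.ReLU(),
            nn.Linear(256, num_actions),
        )

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        return torch.tanh(self.net(state))


class TD3:
    def __init__(
        self,
        actor_network: Actor,  # was: torch.nn.Module
        device: torch.device,
    ):
        # Runtime behaviour is unchanged (Python does not enforce hints);
        # the concrete hint lets mypy/pyright check Actor-specific access.
        self.actor_net = actor_network.to(device)
```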
5 changes: 3 additions & 2 deletions cares_reinforcement_learning/algorithm/mbrl/DynaSAC.py
@@ -16,6 +16,7 @@
 
 import cares_reinforcement_learning.util.helpers as hlp
 from cares_reinforcement_learning.memory import MemoryBuffer
+from cares_reinforcement_learning.networks.DynaSAC import Actor, Critic
 from cares_reinforcement_learning.networks.world_models.ensemble_integrated import (
     EnsembleWorldReward,
 )
@@ -25,8 +25,8 @@
 class DynaSAC:
     def __init__(
         self,
-        actor_network: torch.nn.Module,
-        critic_network: torch.nn.Module,
+        actor_network: Actor,
+        critic_network: Critic,
         world_network: EnsembleWorldReward,
         config: DynaSACConfig,
         device: torch.device,
5 changes: 3 additions & 2 deletions cares_reinforcement_learning/algorithm/policy/CTD4.py
@@ -17,14 +17,15 @@
 
 import cares_reinforcement_learning.util.helpers as hlp
 from cares_reinforcement_learning.memory import MemoryBuffer
+from cares_reinforcement_learning.networks.CTD4 import Actor, EnsembleCritic
 from cares_reinforcement_learning.util.configurations import CTD4Config
 
 
 class CTD4:
     def __init__(
         self,
-        actor_network: torch.nn.Module,
-        ensemble_critics: torch.nn.ModuleList,
+        actor_network: Actor,
+        ensemble_critics: EnsembleCritic,
         config: CTD4Config,
         device: torch.device,
     ):
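CTD4 above (and REDQ further down) goes one step further: the ensemble hint changes from a bare torch.nn.ModuleList to an EnsembleCritic class. A plausible shape for such a class is an nn.Module that owns the list of critics — this sketch is an assumption for illustration, not the repository's actual implementation:

```python
import torch
from torch import nn


class EnsembleCritic(nn.Module):
    """Hypothetical sketch: N Q-networks behind a single nn.Module interface."""

    def __init__(self, ensemble_size: int, observation_size: int, num_actions: int):
        super().__init__()
        self.critics = nn.ModuleList(
            nn.Sequential(
                nn.Linear(observation_size + num_actions, 256),
                nn.ReLU(),
                nn.Linear(256, 1),
            )
            for _ in range(ensemble_size)
        )

    def forward(self, state: torch.Tensor, action: torch.Tensor) -> list[torch.Tensor]:
        # One Q-value estimate per ensemble member.
        x = torch.cat([state, action], dim=1)
        return [critic(x) for critic in self.critics]
```

Wrapping the ModuleList this way keeps .to(device) and state_dict() behaviour identical while giving callers an ensemble-specific, statically checkable API.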
5 changes: 3 additions & 2 deletions cares_reinforcement_learning/algorithm/policy/DDPG.py
@@ -13,14 +13,15 @@
 
 import cares_reinforcement_learning.util.helpers as hlp
 from cares_reinforcement_learning.memory import MemoryBuffer
+from cares_reinforcement_learning.networks.DDPG import Actor, Critic
 from cares_reinforcement_learning.util.configurations import DDPGConfig
 
 
 class DDPG:
     def __init__(
         self,
-        actor_network: torch.nn.Module,
-        critic_network: torch.nn.Module,
+        actor_network: Actor,
+        critic_network: Critic,
         config: DDPGConfig,
         device: torch.device,
     ):
5 changes: 3 additions & 2 deletions cares_reinforcement_learning/algorithm/policy/LA3PSAC.py
@@ -15,13 +15,14 @@
 import cares_reinforcement_learning.util.helpers as hlp
 from cares_reinforcement_learning.memory import MemoryBuffer
 from cares_reinforcement_learning.util.configurations import LA3PSACConfig
+from cares_reinforcement_learning.networks.LA3PSAC import Actor, Critic
 
 
 class LA3PSAC:
     def __init__(
         self,
-        actor_network: torch.nn.Module,
-        critic_network: torch.nn.Module,
+        actor_network: Actor,
+        critic_network: Critic,
         config: LA3PSACConfig,
         device: torch.device,
     ):
5 changes: 3 additions & 2 deletions cares_reinforcement_learning/algorithm/policy/LA3PTD3.py
@@ -14,14 +14,15 @@
 
 import cares_reinforcement_learning.util.helpers as hlp
 from cares_reinforcement_learning.memory import MemoryBuffer
+from cares_reinforcement_learning.networks.LA3PTD3 import Actor, Critic
 from cares_reinforcement_learning.util.configurations import LA3PTD3Config
 
 
 class LA3PTD3:
     def __init__(
         self,
-        actor_network: torch.nn.Module,
-        critic_network: torch.nn.Module,
+        actor_network: Actor,
+        critic_network: Critic,
         config: LA3PTD3Config,
         device: torch.device,
     ):
5 changes: 3 additions & 2 deletions cares_reinforcement_learning/algorithm/policy/LAPSAC.py
@@ -12,14 +12,15 @@
 
 import cares_reinforcement_learning.util.helpers as hlp
 from cares_reinforcement_learning.memory import MemoryBuffer
+from cares_reinforcement_learning.networks.LAPSAC import Actor, Critic
 from cares_reinforcement_learning.util.configurations import LAPSACConfig
 
 
 class LAPSAC:
     def __init__(
         self,
-        actor_network: torch.nn.Module,
-        critic_network: torch.nn.Module,
+        actor_network: Actor,
+        critic_network: Critic,
         config: LAPSACConfig,
         device: torch.device,
     ):
5 changes: 3 additions & 2 deletions cares_reinforcement_learning/algorithm/policy/LAPTD3.py
@@ -12,14 +12,15 @@
 
 import cares_reinforcement_learning.util.helpers as hlp
 from cares_reinforcement_learning.memory import MemoryBuffer
+from cares_reinforcement_learning.networks.LAPTD3 import Actor, Critic
 from cares_reinforcement_learning.util.configurations import LAPTD3Config
 
 
 class LAPTD3:
     def __init__(
         self,
-        actor_network: torch.nn.Module,
-        critic_network: torch.nn.Module,
+        actor_network: Actor,
+        critic_network: Critic,
         config: LAPTD3Config,
         device: torch.device,
     ):
5 changes: 3 additions & 2 deletions cares_reinforcement_learning/algorithm/policy/MAPERSAC.py
@@ -16,14 +16,15 @@
 
 import cares_reinforcement_learning.util.helpers as hlp
 from cares_reinforcement_learning.memory import MemoryBuffer
+from cares_reinforcement_learning.networks.MAPERSAC import Actor, Critic
 from cares_reinforcement_learning.util.configurations import MAPERSACConfig
 
 
 class MAPERSAC:
     def __init__(
         self,
-        actor_network: torch.nn.Module,
-        critic_network: torch.nn.Module,
+        actor_network: Actor,
+        critic_network: Critic,
         config: MAPERSACConfig,
         device: torch.device,
     ):
5 changes: 3 additions & 2 deletions cares_reinforcement_learning/algorithm/policy/MAPERTD3.py
@@ -16,14 +16,15 @@
 
 import cares_reinforcement_learning.util.helpers as hlp
 from cares_reinforcement_learning.memory import MemoryBuffer
+from cares_reinforcement_learning.networks.MAPERTD3 import Actor, Critic
 from cares_reinforcement_learning.util.configurations import MAPERTD3Config
 
 
 class MAPERTD3:
     def __init__(
         self,
-        actor_network: torch.nn.Module,
-        critic_network: torch.nn.Module,
+        actor_network: Actor,
+        critic_network: Critic,
         config: MAPERTD3Config,
         device: torch.device,
     ):
9 changes: 5 additions & 4 deletions cares_reinforcement_learning/algorithm/policy/NaSATD3.py
@@ -12,20 +12,21 @@
 from torch import nn
 
 import cares_reinforcement_learning.util.helpers as hlp
+from cares_reinforcement_learning.encoders.burgess_autoencoder import BurgessAutoencoder
 from cares_reinforcement_learning.encoders.constants import Autoencoders
+from cares_reinforcement_learning.encoders.vanilla_autoencoder import VanillaAutoencoder
 from cares_reinforcement_learning.memory import MemoryBuffer
+from cares_reinforcement_learning.networks.NaSATD3 import Actor, Critic
 from cares_reinforcement_learning.networks.NaSATD3.EPDM import EPDM
 from cares_reinforcement_learning.util.configurations import NaSATD3Config
-from cares_reinforcement_learning.encoders.vanilla_autoencoder import VanillaAutoencoder
-from cares_reinforcement_learning.encoders.burgess_autoencoder import BurgessAutoencoder
 
 
 class NaSATD3:
     def __init__(
         self,
         autoencoder: VanillaAutoencoder | BurgessAutoencoder,
-        actor_network: nn.Module,
-        critic_network: nn.Module,
+        actor_network: Actor,
+        critic_network: Critic,
         config: NaSATD3Config,
         device: torch.device,
     ):
5 changes: 3 additions & 2 deletions cares_reinforcement_learning/algorithm/policy/PALTD3.py
@@ -12,14 +12,15 @@
 
 import cares_reinforcement_learning.util.helpers as hlp
 from cares_reinforcement_learning.memory import MemoryBuffer
+from cares_reinforcement_learning.networks.PALTD3 import Actor, Critic
 from cares_reinforcement_learning.util.configurations import PALTD3Config
 
 
 class PALTD3:
     def __init__(
         self,
-        actor_network: torch.nn.Module,
-        critic_network: torch.nn.Module,
+        actor_network: Actor,
+        critic_network: Critic,
         config: PALTD3Config,
         device: torch.device,
     ):
5 changes: 3 additions & 2 deletions cares_reinforcement_learning/algorithm/policy/PERSAC.py
@@ -13,14 +13,15 @@
 
 import cares_reinforcement_learning.util.helpers as hlp
 from cares_reinforcement_learning.memory import MemoryBuffer
+from cares_reinforcement_learning.networks.PERSAC import Actor, Critic
 from cares_reinforcement_learning.util.configurations import PERSACConfig
 
 
 class PERSAC:
     def __init__(
         self,
-        actor_network: torch.nn.Module,
-        critic_network: torch.nn.Module,
+        actor_network: Actor,
+        critic_network: Critic,
         config: PERSACConfig,
         device: torch.device,
     ):
5 changes: 3 additions & 2 deletions cares_reinforcement_learning/algorithm/policy/PERTD3.py
@@ -13,14 +13,15 @@
 
 import cares_reinforcement_learning.util.helpers as hlp
 from cares_reinforcement_learning.memory import MemoryBuffer
+from cares_reinforcement_learning.networks.PERTD3 import Actor, Critic
 from cares_reinforcement_learning.util.configurations import PERTD3Config
 
 
 class PERTD3:
     def __init__(
         self,
-        actor_network: torch.nn.Module,
-        critic_network: torch.nn.Module,
+        actor_network: Actor,
+        critic_network: Critic,
         config: PERTD3Config,
         device: torch.device,
     ):
5 changes: 3 additions & 2 deletions cares_reinforcement_learning/algorithm/policy/PPO.py
@@ -18,14 +18,15 @@
 from torch.distributions import MultivariateNormal
 
 from cares_reinforcement_learning.memory import MemoryBuffer
+from cares_reinforcement_learning.networks.PPO import Actor, Critic
 from cares_reinforcement_learning.util.configurations import PPOConfig
 
 
 class PPO:
     def __init__(
         self,
-        actor_network: torch.nn.Module,
-        critic_network: torch.nn.Module,
+        actor_network: Actor,
+        critic_network: Critic,
         config: PPOConfig,
         device: torch.device,
     ):
5 changes: 3 additions & 2 deletions cares_reinforcement_learning/algorithm/policy/RDSAC.py
@@ -9,14 +9,15 @@
 
 import cares_reinforcement_learning.util.helpers as hlp
 from cares_reinforcement_learning.memory import MemoryBuffer
+from cares_reinforcement_learning.networks.RDSAC import Actor, Critic
 from cares_reinforcement_learning.util.configurations import RDSACConfig
 
 
 class RDSAC:
     def __init__(
         self,
-        actor_network: torch.nn.Module,
-        critic_network: torch.nn.Module,
+        actor_network: Actor,
+        critic_network: Critic,
         config: RDSACConfig,
         device: torch.device,
     ):
5 changes: 3 additions & 2 deletions cares_reinforcement_learning/algorithm/policy/RDTD3.py
@@ -10,14 +10,15 @@
 
 import cares_reinforcement_learning.util.helpers as hlp
 from cares_reinforcement_learning.memory import MemoryBuffer
+from cares_reinforcement_learning.networks.RDTD3 import Actor, Critic
 from cares_reinforcement_learning.util.configurations import RDTD3Config
 
 
 class RDTD3:
     def __init__(
         self,
-        actor_network: torch.nn.Module,
-        critic_network: torch.nn.Module,
+        actor_network: Actor,
+        critic_network: Critic,
         config: RDTD3Config,
         device: torch.device,
     ):
5 changes: 3 additions & 2 deletions cares_reinforcement_learning/algorithm/policy/REDQ.py
@@ -13,14 +13,15 @@
 
 import cares_reinforcement_learning.util.helpers as hlp
 from cares_reinforcement_learning.memory import MemoryBuffer
+from cares_reinforcement_learning.networks.REDQ import Actor, EnsembleCritic
 from cares_reinforcement_learning.util.configurations import REDQConfig
 
 
 class REDQ:
     def __init__(
         self,
-        actor_network: torch.nn.Module,
-        ensemble_critics: torch.nn.ModuleList,
+        actor_network: Actor,
+        ensemble_critics: EnsembleCritic,
         config: REDQConfig,
         device: torch.device,
     ):
5 changes: 3 additions & 2 deletions cares_reinforcement_learning/algorithm/policy/SAC.py
@@ -16,14 +16,15 @@
 
 import cares_reinforcement_learning.util.helpers as hlp
 from cares_reinforcement_learning.memory import MemoryBuffer
+from cares_reinforcement_learning.networks.SAC import Actor, Critic
 from cares_reinforcement_learning.util.configurations import SACConfig
 
 
 class SAC:
     def __init__(
         self,
-        actor_network: torch.nn.Module,
-        critic_network: torch.nn.Module,
+        actor_network: Actor,
+        critic_network: Critic,
         config: SACConfig,
         device: torch.device,
     ):
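With the concrete hints in place, construction (here for SAC above) reads the same as before but type-checks end to end. A hedged usage sketch — the Actor/Critic constructor arguments and the bare SACConfig() call are assumptions for illustration, not confirmed signatures:

```python
import torch

from cares_reinforcement_learning.algorithm.policy.SAC import SAC
from cares_reinforcement_learning.networks.SAC import Actor, Critic
from cares_reinforcement_learning.util.configurations import SACConfig

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
config = SACConfig()  # assumed: constructible with defaults

# Constructor arguments are assumed for illustration.
actor = Actor(observation_size=17, num_actions=6)
critic = Critic(observation_size=17, num_actions=6)

# A bare nn.Module passed here is now flagged by mypy/pyright,
# while runtime behaviour is unchanged.
agent = SAC(actor_network=actor, critic_network=critic, config=config, device=device)
```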
5 changes: 3 additions & 2 deletions cares_reinforcement_learning/algorithm/policy/SACAE.py
@@ -18,14 +18,15 @@
 from cares_reinforcement_learning.encoders.losses import AELoss
 from cares_reinforcement_learning.encoders.vanilla_autoencoder import Decoder
 from cares_reinforcement_learning.memory import MemoryBuffer
+from cares_reinforcement_learning.networks.SACAE import Actor, Critic
 from cares_reinforcement_learning.util.configurations import SACAEConfig
 
 
 class SACAE:
     def __init__(
         self,
-        actor_network: torch.nn.Module,
-        critic_network: torch.nn.Module,
+        actor_network: Actor,
+        critic_network: Critic,
         decoder_network: Decoder,
         config: SACAEConfig,
         device: torch.device,
5 changes: 3 additions & 2 deletions cares_reinforcement_learning/algorithm/policy/SACD.py
@@ -16,14 +16,15 @@
 
 import cares_reinforcement_learning.util.helpers as hlp
 from cares_reinforcement_learning.memory import MemoryBuffer
+from cares_reinforcement_learning.networks.SACD import Actor, Critic
 from cares_reinforcement_learning.util.configurations import SACDConfig
 
 
 class SACD:
     def __init__(
         self,
-        actor_network: torch.nn.Module,
-        critic_network: torch.nn.Module,
+        actor_network: Actor,
+        critic_network: Critic,
         config: SACDConfig,
         device: torch.device,
     ):
5 changes: 3 additions & 2 deletions cares_reinforcement_learning/algorithm/policy/TD3.py
@@ -15,14 +15,15 @@
 
 import cares_reinforcement_learning.util.helpers as hlp
 from cares_reinforcement_learning.memory import MemoryBuffer
+from cares_reinforcement_learning.networks.TD3 import Actor, Critic
 from cares_reinforcement_learning.util.configurations import TD3Config
 
 
 class TD3:
     def __init__(
         self,
-        actor_network: torch.nn.Module,
-        critic_network: torch.nn.Module,
+        actor_network: Actor,
+        critic_network: Critic,
         config: TD3Config,
         device: torch.device,
     ):
