From 56be414baa8945a264fad89a2aa6948635c06c4f Mon Sep 17 00:00:00 2001 From: beardyface Date: Tue, 26 Nov 2024 09:54:22 +1300 Subject: [PATCH] Type hint for types in networks --- cares_reinforcement_learning/algorithm/mbrl/DynaSAC.py | 5 +++-- cares_reinforcement_learning/algorithm/policy/CTD4.py | 5 +++-- cares_reinforcement_learning/algorithm/policy/DDPG.py | 5 +++-- cares_reinforcement_learning/algorithm/policy/LA3PSAC.py | 5 +++-- cares_reinforcement_learning/algorithm/policy/LA3PTD3.py | 5 +++-- cares_reinforcement_learning/algorithm/policy/LAPSAC.py | 5 +++-- cares_reinforcement_learning/algorithm/policy/LAPTD3.py | 5 +++-- .../algorithm/policy/MAPERSAC.py | 5 +++-- .../algorithm/policy/MAPERTD3.py | 5 +++-- cares_reinforcement_learning/algorithm/policy/NaSATD3.py | 9 +++++---- cares_reinforcement_learning/algorithm/policy/PALTD3.py | 5 +++-- cares_reinforcement_learning/algorithm/policy/PERSAC.py | 5 +++-- cares_reinforcement_learning/algorithm/policy/PERTD3.py | 5 +++-- cares_reinforcement_learning/algorithm/policy/PPO.py | 5 +++-- cares_reinforcement_learning/algorithm/policy/RDSAC.py | 5 +++-- cares_reinforcement_learning/algorithm/policy/RDTD3.py | 5 +++-- cares_reinforcement_learning/algorithm/policy/REDQ.py | 5 +++-- cares_reinforcement_learning/algorithm/policy/SAC.py | 5 +++-- cares_reinforcement_learning/algorithm/policy/SACAE.py | 5 +++-- cares_reinforcement_learning/algorithm/policy/SACD.py | 5 +++-- cares_reinforcement_learning/algorithm/policy/TD3.py | 5 +++-- cares_reinforcement_learning/algorithm/policy/TD3AE.py | 9 ++++++--- cares_reinforcement_learning/algorithm/policy/TQC.py | 5 +++-- cares_reinforcement_learning/algorithm/value/DQN.py | 6 +++++- .../algorithm/value/DoubleDQN.py | 3 ++- cares_reinforcement_learning/networks/MAPERTD3/critic.py | 1 - cares_reinforcement_learning/networks/RDSAC/critic.py | 1 - cares_reinforcement_learning/networks/RDTD3/critic.py | 1 - 28 files changed, 81 insertions(+), 54 deletions(-) diff --git a/cares_reinforcement_learning/algorithm/mbrl/DynaSAC.py b/cares_reinforcement_learning/algorithm/mbrl/DynaSAC.py index 1bc8188..e5e2abc 100644 --- a/cares_reinforcement_learning/algorithm/mbrl/DynaSAC.py +++ b/cares_reinforcement_learning/algorithm/mbrl/DynaSAC.py @@ -16,6 +16,7 @@ import cares_reinforcement_learning.util.helpers as hlp from cares_reinforcement_learning.memory import MemoryBuffer +from cares_reinforcement_learning.networks.DynaSAC import Actor, Critic from cares_reinforcement_learning.networks.world_models.ensemble_integrated import ( EnsembleWorldReward, ) @@ -25,8 +26,8 @@ class DynaSAC: def __init__( self, - actor_network: torch.nn.Module, - critic_network: torch.nn.Module, + actor_network: Actor, + critic_network: Critic, world_network: EnsembleWorldReward, config: DynaSACConfig, device: torch.device, diff --git a/cares_reinforcement_learning/algorithm/policy/CTD4.py b/cares_reinforcement_learning/algorithm/policy/CTD4.py index 87dc751..552b034 100644 --- a/cares_reinforcement_learning/algorithm/policy/CTD4.py +++ b/cares_reinforcement_learning/algorithm/policy/CTD4.py @@ -17,14 +17,15 @@ import cares_reinforcement_learning.util.helpers as hlp from cares_reinforcement_learning.memory import MemoryBuffer +from cares_reinforcement_learning.networks.CTD4 import Actor, EnsembleCritic from cares_reinforcement_learning.util.configurations import CTD4Config class CTD4: def __init__( self, - actor_network: torch.nn.Module, - ensemble_critics: torch.nn.ModuleList, + actor_network: Actor, + ensemble_critics: EnsembleCritic, config: CTD4Config, device: torch.device, ): diff --git a/cares_reinforcement_learning/algorithm/policy/DDPG.py b/cares_reinforcement_learning/algorithm/policy/DDPG.py index 96ad313..574a806 100644 --- a/cares_reinforcement_learning/algorithm/policy/DDPG.py +++ b/cares_reinforcement_learning/algorithm/policy/DDPG.py @@ -13,14 +13,15 @@ import cares_reinforcement_learning.util.helpers as hlp from cares_reinforcement_learning.memory import MemoryBuffer +from cares_reinforcement_learning.networks.DDPG import Actor, Critic from cares_reinforcement_learning.util.configurations import DDPGConfig class DDPG: def __init__( self, - actor_network: torch.nn.Module, - critic_network: torch.nn.Module, + actor_network: Actor, + critic_network: Critic, config: DDPGConfig, device: torch.device, ): diff --git a/cares_reinforcement_learning/algorithm/policy/LA3PSAC.py b/cares_reinforcement_learning/algorithm/policy/LA3PSAC.py index d5da124..23186f8 100644 --- a/cares_reinforcement_learning/algorithm/policy/LA3PSAC.py +++ b/cares_reinforcement_learning/algorithm/policy/LA3PSAC.py @@ -15,13 +15,14 @@ import cares_reinforcement_learning.util.helpers as hlp from cares_reinforcement_learning.memory import MemoryBuffer from cares_reinforcement_learning.util.configurations import LA3PSACConfig +from cares_reinforcement_learning.networks.LA3PSAC import Actor, Critic class LA3PSAC: def __init__( self, - actor_network: torch.nn.Module, - critic_network: torch.nn.Module, + actor_network: Actor, + critic_network: Critic, config: LA3PSACConfig, device: torch.device, ): diff --git a/cares_reinforcement_learning/algorithm/policy/LA3PTD3.py b/cares_reinforcement_learning/algorithm/policy/LA3PTD3.py index 8dacb43..d7f918a 100644 --- a/cares_reinforcement_learning/algorithm/policy/LA3PTD3.py +++ b/cares_reinforcement_learning/algorithm/policy/LA3PTD3.py @@ -14,14 +14,15 @@ import cares_reinforcement_learning.util.helpers as hlp from cares_reinforcement_learning.memory import MemoryBuffer +from cares_reinforcement_learning.networks.LA3PTD3 import Actor, Critic from cares_reinforcement_learning.util.configurations import LA3PTD3Config class LA3PTD3: def __init__( self, - actor_network: torch.nn.Module, - critic_network: torch.nn.Module, + actor_network: Actor, + critic_network: Critic, config: LA3PTD3Config, device: torch.device, ): diff --git a/cares_reinforcement_learning/algorithm/policy/LAPSAC.py b/cares_reinforcement_learning/algorithm/policy/LAPSAC.py index f38afd7..efdc89d 100644 --- a/cares_reinforcement_learning/algorithm/policy/LAPSAC.py +++ b/cares_reinforcement_learning/algorithm/policy/LAPSAC.py @@ -12,14 +12,15 @@ import cares_reinforcement_learning.util.helpers as hlp from cares_reinforcement_learning.memory import MemoryBuffer +from cares_reinforcement_learning.networks.LAPSAC import Actor, Critic from cares_reinforcement_learning.util.configurations import LAPSACConfig class LAPSAC: def __init__( self, - actor_network: torch.nn.Module, - critic_network: torch.nn.Module, + actor_network: Actor, + critic_network: Critic, config: LAPSACConfig, device: torch.device, ): diff --git a/cares_reinforcement_learning/algorithm/policy/LAPTD3.py b/cares_reinforcement_learning/algorithm/policy/LAPTD3.py index 6a90a03..5b2c660 100644 --- a/cares_reinforcement_learning/algorithm/policy/LAPTD3.py +++ b/cares_reinforcement_learning/algorithm/policy/LAPTD3.py @@ -12,14 +12,15 @@ import cares_reinforcement_learning.util.helpers as hlp from cares_reinforcement_learning.memory import MemoryBuffer +from cares_reinforcement_learning.networks.LAPTD3 import Actor, Critic from cares_reinforcement_learning.util.configurations import LAPTD3Config class LAPTD3: def __init__( self, - actor_network: torch.nn.Module, - critic_network: torch.nn.Module, + actor_network: Actor, + critic_network: Critic, config: LAPTD3Config, device: torch.device, ): diff --git a/cares_reinforcement_learning/algorithm/policy/MAPERSAC.py b/cares_reinforcement_learning/algorithm/policy/MAPERSAC.py index ca083fe..027d37a 100644 --- a/cares_reinforcement_learning/algorithm/policy/MAPERSAC.py +++ b/cares_reinforcement_learning/algorithm/policy/MAPERSAC.py @@ -16,14 +16,15 @@ import cares_reinforcement_learning.util.helpers as hlp from cares_reinforcement_learning.memory import MemoryBuffer +from cares_reinforcement_learning.networks.MAPERSAC import Actor, Critic from cares_reinforcement_learning.util.configurations import MAPERSACConfig class MAPERSAC: def __init__( self, - actor_network: torch.nn.Module, - critic_network: torch.nn.Module, + actor_network: Actor, + critic_network: Critic, config: MAPERSACConfig, device: torch.device, ): diff --git a/cares_reinforcement_learning/algorithm/policy/MAPERTD3.py b/cares_reinforcement_learning/algorithm/policy/MAPERTD3.py index 4e00346..ac7b57b 100644 --- a/cares_reinforcement_learning/algorithm/policy/MAPERTD3.py +++ b/cares_reinforcement_learning/algorithm/policy/MAPERTD3.py @@ -16,14 +16,15 @@ import cares_reinforcement_learning.util.helpers as hlp from cares_reinforcement_learning.memory import MemoryBuffer +from cares_reinforcement_learning.networks.MAPERTD3 import Actor, Critic from cares_reinforcement_learning.util.configurations import MAPERTD3Config class MAPERTD3: def __init__( self, - actor_network: torch.nn.Module, - critic_network: torch.nn.Module, + actor_network: Actor, + critic_network: Critic, config: MAPERTD3Config, device: torch.device, ): diff --git a/cares_reinforcement_learning/algorithm/policy/NaSATD3.py b/cares_reinforcement_learning/algorithm/policy/NaSATD3.py index 33c6124..8142868 100644 --- a/cares_reinforcement_learning/algorithm/policy/NaSATD3.py +++ b/cares_reinforcement_learning/algorithm/policy/NaSATD3.py @@ -12,20 +12,21 @@ from torch import nn import cares_reinforcement_learning.util.helpers as hlp +from cares_reinforcement_learning.encoders.burgess_autoencoder import BurgessAutoencoder from cares_reinforcement_learning.encoders.constants import Autoencoders +from cares_reinforcement_learning.encoders.vanilla_autoencoder import VanillaAutoencoder from cares_reinforcement_learning.memory import MemoryBuffer +from cares_reinforcement_learning.networks.NaSATD3 import Actor, Critic from cares_reinforcement_learning.networks.NaSATD3.EPDM import EPDM from cares_reinforcement_learning.util.configurations import NaSATD3Config -from cares_reinforcement_learning.encoders.vanilla_autoencoder import VanillaAutoencoder -from cares_reinforcement_learning.encoders.burgess_autoencoder import BurgessAutoencoder class NaSATD3: def __init__( self, autoencoder: VanillaAutoencoder | BurgessAutoencoder, - actor_network: nn.Module, - critic_network: nn.Module, + actor_network: Actor, + critic_network: Critic, config: NaSATD3Config, device: torch.device, ): diff --git a/cares_reinforcement_learning/algorithm/policy/PALTD3.py b/cares_reinforcement_learning/algorithm/policy/PALTD3.py index 3737a39..d8bc93b 100644 --- a/cares_reinforcement_learning/algorithm/policy/PALTD3.py +++ b/cares_reinforcement_learning/algorithm/policy/PALTD3.py @@ -12,14 +12,15 @@ import cares_reinforcement_learning.util.helpers as hlp from cares_reinforcement_learning.memory import MemoryBuffer +from cares_reinforcement_learning.networks.PALTD3 import Actor, Critic from cares_reinforcement_learning.util.configurations import PALTD3Config class PALTD3: def __init__( self, - actor_network: torch.nn.Module, - critic_network: torch.nn.Module, + actor_network: Actor, + critic_network: Critic, config: PALTD3Config, device: torch.device, ): diff --git a/cares_reinforcement_learning/algorithm/policy/PERSAC.py b/cares_reinforcement_learning/algorithm/policy/PERSAC.py index 591f27b..ec5796c 100644 --- a/cares_reinforcement_learning/algorithm/policy/PERSAC.py +++ b/cares_reinforcement_learning/algorithm/policy/PERSAC.py @@ -13,14 +13,15 @@ import cares_reinforcement_learning.util.helpers as hlp from cares_reinforcement_learning.memory import MemoryBuffer +from cares_reinforcement_learning.networks.PERSAC import Actor, Critic from cares_reinforcement_learning.util.configurations import PERSACConfig class PERSAC: def __init__( self, - actor_network: torch.nn.Module, - critic_network: torch.nn.Module, + actor_network: Actor, + critic_network: Critic, config: PERSACConfig, device: torch.device, ): diff --git a/cares_reinforcement_learning/algorithm/policy/PERTD3.py b/cares_reinforcement_learning/algorithm/policy/PERTD3.py index b122d22..2d8219b 100644 --- a/cares_reinforcement_learning/algorithm/policy/PERTD3.py +++ b/cares_reinforcement_learning/algorithm/policy/PERTD3.py @@ -13,14 +13,15 @@ import cares_reinforcement_learning.util.helpers as hlp from cares_reinforcement_learning.memory import MemoryBuffer +from cares_reinforcement_learning.networks.PERTD3 import Actor, Critic from cares_reinforcement_learning.util.configurations import PERTD3Config class PERTD3: def __init__( self, - actor_network: torch.nn.Module, - critic_network: torch.nn.Module, + actor_network: Actor, + critic_network: Critic, config: PERTD3Config, device: torch.device, ): diff --git a/cares_reinforcement_learning/algorithm/policy/PPO.py b/cares_reinforcement_learning/algorithm/policy/PPO.py index 830a753..9f11151 100644 --- a/cares_reinforcement_learning/algorithm/policy/PPO.py +++ b/cares_reinforcement_learning/algorithm/policy/PPO.py @@ -18,14 +18,15 @@ from torch.distributions import MultivariateNormal from cares_reinforcement_learning.memory import MemoryBuffer +from cares_reinforcement_learning.networks.PPO import Actor, Critic from cares_reinforcement_learning.util.configurations import PPOConfig class PPO: def __init__( self, - actor_network: torch.nn.Module, - critic_network: torch.nn.Module, + actor_network: Actor, + critic_network: Critic, config: PPOConfig, device: torch.device, ): diff --git a/cares_reinforcement_learning/algorithm/policy/RDSAC.py b/cares_reinforcement_learning/algorithm/policy/RDSAC.py index 1645eb4..c47551a 100644 --- a/cares_reinforcement_learning/algorithm/policy/RDSAC.py +++ b/cares_reinforcement_learning/algorithm/policy/RDSAC.py @@ -9,14 +9,15 @@ import cares_reinforcement_learning.util.helpers as hlp from cares_reinforcement_learning.memory import MemoryBuffer +from cares_reinforcement_learning.networks.RDSAC import Actor, Critic from cares_reinforcement_learning.util.configurations import RDSACConfig class RDSAC: def __init__( self, - actor_network: torch.nn.Module, - critic_network: torch.nn.Module, + actor_network: Actor, + critic_network: Critic, config: RDSACConfig, device: torch.device, ): diff --git a/cares_reinforcement_learning/algorithm/policy/RDTD3.py b/cares_reinforcement_learning/algorithm/policy/RDTD3.py index aaf6ce1..d95774d 100644 --- a/cares_reinforcement_learning/algorithm/policy/RDTD3.py +++ b/cares_reinforcement_learning/algorithm/policy/RDTD3.py @@ -10,14 +10,15 @@ import cares_reinforcement_learning.util.helpers as hlp from cares_reinforcement_learning.memory import MemoryBuffer +from cares_reinforcement_learning.networks.RDTD3 import Actor, Critic from cares_reinforcement_learning.util.configurations import RDTD3Config class RDTD3: def __init__( self, - actor_network: torch.nn.Module, - critic_network: torch.nn.Module, + actor_network: Actor, + critic_network: Critic, config: RDTD3Config, device: torch.device, ): diff --git a/cares_reinforcement_learning/algorithm/policy/REDQ.py b/cares_reinforcement_learning/algorithm/policy/REDQ.py index 447b04e..dbe1891 100644 --- a/cares_reinforcement_learning/algorithm/policy/REDQ.py +++ b/cares_reinforcement_learning/algorithm/policy/REDQ.py @@ -13,14 +13,15 @@ import cares_reinforcement_learning.util.helpers as hlp from cares_reinforcement_learning.memory import MemoryBuffer +from cares_reinforcement_learning.networks.REDQ import Actor, EnsembleCritic from cares_reinforcement_learning.util.configurations import REDQConfig class REDQ: def __init__( self, - actor_network: torch.nn.Module, - ensemble_critics: torch.nn.ModuleList, + actor_network: Actor, + ensemble_critics: EnsembleCritic, config: REDQConfig, device: torch.device, ): diff --git a/cares_reinforcement_learning/algorithm/policy/SAC.py b/cares_reinforcement_learning/algorithm/policy/SAC.py index 6512163..6ea4b76 100644 --- a/cares_reinforcement_learning/algorithm/policy/SAC.py +++ b/cares_reinforcement_learning/algorithm/policy/SAC.py @@ -16,14 +16,15 @@ import cares_reinforcement_learning.util.helpers as hlp from cares_reinforcement_learning.memory import MemoryBuffer +from cares_reinforcement_learning.networks.SAC import Actor, Critic from cares_reinforcement_learning.util.configurations import SACConfig class SAC: def __init__( self, - actor_network: torch.nn.Module, - critic_network: torch.nn.Module, + actor_network: Actor, + critic_network: Critic, config: SACConfig, device: torch.device, ): diff --git a/cares_reinforcement_learning/algorithm/policy/SACAE.py b/cares_reinforcement_learning/algorithm/policy/SACAE.py index 6a73138..6785e37 100644 --- a/cares_reinforcement_learning/algorithm/policy/SACAE.py +++ b/cares_reinforcement_learning/algorithm/policy/SACAE.py @@ -18,14 +18,15 @@ from cares_reinforcement_learning.encoders.losses import AELoss from cares_reinforcement_learning.encoders.vanilla_autoencoder import Decoder from cares_reinforcement_learning.memory import MemoryBuffer +from cares_reinforcement_learning.networks.SACAE import Actor, Critic from cares_reinforcement_learning.util.configurations import SACAEConfig class SACAE: def __init__( self, - actor_network: torch.nn.Module, - critic_network: torch.nn.Module, + actor_network: Actor, + critic_network: Critic, decoder_network: Decoder, config: SACAEConfig, device: torch.device, diff --git a/cares_reinforcement_learning/algorithm/policy/SACD.py b/cares_reinforcement_learning/algorithm/policy/SACD.py index 40c5865..1cb78b8 100644 --- a/cares_reinforcement_learning/algorithm/policy/SACD.py +++ b/cares_reinforcement_learning/algorithm/policy/SACD.py @@ -16,14 +16,15 @@ import cares_reinforcement_learning.util.helpers as hlp from cares_reinforcement_learning.memory import MemoryBuffer +from cares_reinforcement_learning.networks.SACD import Actor, Critic from cares_reinforcement_learning.util.configurations import SACDConfig class SACD: def __init__( self, - actor_network: torch.nn.Module, - critic_network: torch.nn.Module, + actor_network: Actor, + critic_network: Critic, config: SACDConfig, device: torch.device, ): diff --git a/cares_reinforcement_learning/algorithm/policy/TD3.py b/cares_reinforcement_learning/algorithm/policy/TD3.py index 91779ec..be6b0ab 100644 --- a/cares_reinforcement_learning/algorithm/policy/TD3.py +++ b/cares_reinforcement_learning/algorithm/policy/TD3.py @@ -15,14 +15,15 @@ import cares_reinforcement_learning.util.helpers as hlp from cares_reinforcement_learning.memory import MemoryBuffer +from cares_reinforcement_learning.networks.TD3 import Actor, Critic from cares_reinforcement_learning.util.configurations import TD3Config class TD3: def __init__( self, - actor_network: torch.nn.Module, - critic_network: torch.nn.Module, + actor_network: Actor, + critic_network: Critic, config: TD3Config, device: torch.device, ): diff --git a/cares_reinforcement_learning/algorithm/policy/TD3AE.py b/cares_reinforcement_learning/algorithm/policy/TD3AE.py index 12dae8b..d62cd0f 100644 --- a/cares_reinforcement_learning/algorithm/policy/TD3AE.py +++ b/cares_reinforcement_learning/algorithm/policy/TD3AE.py @@ -16,14 +16,15 @@ from cares_reinforcement_learning.encoders.losses import AELoss from cares_reinforcement_learning.encoders.vanilla_autoencoder import Decoder from cares_reinforcement_learning.memory import MemoryBuffer +from cares_reinforcement_learning.networks.TD3AE import Actor, Critic from cares_reinforcement_learning.util.configurations import TD3AEConfig class TD3AE: def __init__( self, - actor_network: torch.nn.Module, - critic_network: torch.nn.Module, + actor_network: Actor, + critic_network: Critic, decoder_network: Decoder, config: TD3AEConfig, device: torch.device, @@ -216,7 +217,9 @@ def train_policy(self, memory: MemoryBuffer, batch_size: int) -> dict[str, Any]: ) hlp.soft_update_params( - self.actor_net.act_net, self.target_actor_net.act_net, self.encoder_tau + self.actor_net.actor.act_net, + self.target_actor_net.actor.act_net, + self.encoder_tau, ) hlp.soft_update_params( diff --git a/cares_reinforcement_learning/algorithm/policy/TQC.py b/cares_reinforcement_learning/algorithm/policy/TQC.py index 93e1d9d..df95170 100644 --- a/cares_reinforcement_learning/algorithm/policy/TQC.py +++ b/cares_reinforcement_learning/algorithm/policy/TQC.py @@ -14,6 +14,7 @@ import torch from cares_reinforcement_learning.memory import MemoryBuffer +from cares_reinforcement_learning.networks.TQC import Actor, Critic from cares_reinforcement_learning.util import helpers as hlp from cares_reinforcement_learning.util.configurations import TQCConfig @@ -21,8 +22,8 @@ class TQC: def __init__( self, - actor_network: torch.nn.Module, - critic_network: torch.nn.Module, + actor_network: Actor, + critic_network: Critic, config: TQCConfig, device: torch.device, ): diff --git a/cares_reinforcement_learning/algorithm/value/DQN.py b/cares_reinforcement_learning/algorithm/value/DQN.py index 847a81d..7af3e49 100644 --- a/cares_reinforcement_learning/algorithm/value/DQN.py +++ b/cares_reinforcement_learning/algorithm/value/DQN.py @@ -11,13 +11,17 @@ import torch.nn.functional as F from cares_reinforcement_learning.memory import MemoryBuffer +from cares_reinforcement_learning.networks.DQN import Network as DQNNetwork +from cares_reinforcement_learning.networks.DuelingDQN import ( + Network as DuelingDQNNetwork, +) from cares_reinforcement_learning.util.configurations import DQNConfig, DuelingDQNConfig class DQN: def __init__( self, - network: torch.nn.Module, + network: DQNNetwork | DuelingDQNNetwork, config: DQNConfig | DuelingDQNConfig, device: torch.device, ): diff --git a/cares_reinforcement_learning/algorithm/value/DoubleDQN.py b/cares_reinforcement_learning/algorithm/value/DoubleDQN.py index db7425e..1555a70 100644 --- a/cares_reinforcement_learning/algorithm/value/DoubleDQN.py +++ b/cares_reinforcement_learning/algorithm/value/DoubleDQN.py @@ -14,13 +14,14 @@ import torch.nn.functional as F from cares_reinforcement_learning.memory import MemoryBuffer +from cares_reinforcement_learning.networks.DoubleDQN import Network from cares_reinforcement_learning.util.configurations import DoubleDQNConfig class DoubleDQN: def __init__( self, - network: torch.nn.Module, + network: Network, config: DoubleDQNConfig, device: torch.device, ): diff --git a/cares_reinforcement_learning/networks/MAPERTD3/critic.py b/cares_reinforcement_learning/networks/MAPERTD3/critic.py index c7c2efd..4ee4f63 100644 --- a/cares_reinforcement_learning/networks/MAPERTD3/critic.py +++ b/cares_reinforcement_learning/networks/MAPERTD3/critic.py @@ -1,4 +1,3 @@ -import torch from torch import nn from cares_reinforcement_learning.networks.TD3 import BaseCritic diff --git a/cares_reinforcement_learning/networks/RDSAC/critic.py b/cares_reinforcement_learning/networks/RDSAC/critic.py index 4ddd3ae..f1ab518 100644 --- a/cares_reinforcement_learning/networks/RDSAC/critic.py +++ b/cares_reinforcement_learning/networks/RDSAC/critic.py @@ -1,4 +1,3 @@ -import torch from torch import nn from cares_reinforcement_learning.networks.SAC import BaseCritic diff --git a/cares_reinforcement_learning/networks/RDTD3/critic.py b/cares_reinforcement_learning/networks/RDTD3/critic.py index 9b72ede..4cfb050 100644 --- a/cares_reinforcement_learning/networks/RDTD3/critic.py +++ b/cares_reinforcement_learning/networks/RDTD3/critic.py @@ -1,4 +1,3 @@ -import torch from torch import nn from cares_reinforcement_learning.networks.TD3 import BaseCritic