CrossQ Base, Default, Custom - not MLP yet as different layout of batchnorm
beardyFace committed Dec 2, 2024
1 parent 91c3532 commit 2d54abb
Showing 5 changed files with 171 additions and 62 deletions.
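
At a glance, the commit splits each CrossQ network into three layers: a base class (BaseActor/BaseCritic) that owns forward() and receives prebuilt modules, a default class that hard-codes the BatchRenorm1d MLP layout, and a config-driven Actor/Critic built from CrossQConfig. A minimal usage sketch of the two constructible variants, based on the diffs below; the example sizes are arbitrary, it assumes CrossQConfig's pydantic-style fields accept keyword overrides, and note that DefaultCritic is not re-exported from the package __init__:

from cares_reinforcement_learning.networks.CrossQ import Actor, Critic, DefaultActor
from cares_reinforcement_learning.networks.CrossQ.critic import DefaultCritic
from cares_reinforcement_learning.util.configurations import CrossQConfig

# Default variants: fixed layouts ([256, 256] actor, [2048, 2048] critic).
actor = DefaultActor(observation_size=17, num_actions=6)
critic = DefaultCritic(observation_size=17, num_actions=6)

# Config-driven variants: layout and log-std bounds read from CrossQConfig.
config = CrossQConfig(hidden_size_actor=[256, 256], hidden_size_critic=[2048, 2048])
actor = Actor(observation_size=17, num_actions=6, config=config)
critic = Critic(observation_size=17, num_actions=6, config=config)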
2 changes: 1 addition & 1 deletion cares_reinforcement_learning/networks/CrossQ/__init__.py
@@ -1,2 +1,2 @@
-from .actor import Actor
+from .actor import Actor, DefaultActor
 from .critic import Critic
107 changes: 83 additions & 24 deletions cares_reinforcement_learning/networks/CrossQ/actor.py
@@ -3,42 +3,29 @@
 from torch import nn
 
 from cares_reinforcement_learning.util.common import SquashedNormal
+from cares_reinforcement_learning.util.configurations import CrossQConfig
 
 
-class Actor(nn.Module):
-    # DiagGaussianActor
-    """torch.distributions implementation of an diagonal Gaussian policy."""
-
+class BaseActor(nn.Module):
     def __init__(
         self,
-        observation_size: int,
+        act_net: nn.Module,
+        mean_linear: nn.Linear,
+        log_std_linear: nn.Linear,
         num_actions: int,
-        hidden_size: list[int] = None,
-        log_std_bounds: list[int] = None,
+        log_std_bounds: list[float] | None = None,
    ):
         super().__init__()
-        if hidden_size is None:
-            hidden_size = [256, 256]
         if log_std_bounds is None:
-            log_std_bounds = [-5, 2]
+            log_std_bounds = [-20, 2]
 
-        self.num_actions = num_actions
-        self.hidden_size = hidden_size
         self.log_std_bounds = log_std_bounds
 
-        momentum = 0.01
-        self.act_net = nn.Sequential(
-            BatchRenorm1d(observation_size, momentum=momentum),
-            nn.Linear(observation_size, self.hidden_size[0], bias=False),
-            nn.ReLU(),
-            BatchRenorm1d(self.hidden_size[0], momentum=momentum),
-            nn.Linear(self.hidden_size[0], self.hidden_size[1], bias=False),
-            nn.ReLU(),
-            BatchRenorm1d(self.hidden_size[1], momentum=momentum),
-        )
+        self.num_actions = num_actions
+        self.act_net = act_net
 
-        self.mean_linear = nn.Linear(self.hidden_size[1], num_actions)
-        self.log_std_linear = nn.Linear(self.hidden_size[1], num_actions)
+        self.mean_linear = mean_linear
+        self.log_std_linear = log_std_linear
 
     def forward(
         self, state: torch.Tensor
@@ -64,3 +51,75 @@ def forward(
         log_pi = dist.log_prob(sample).sum(-1, keepdim=True)
 
         return sample, log_pi, dist.mean
+
+
+class DefaultActor(BaseActor):
+    # DiagGaussianActor
+    """torch.distributions implementation of an diagonal Gaussian policy."""
+
+    def __init__(
+        self,
+        observation_size: int,
+        num_actions: int,
+    ):
+
+        hidden_sizes = [256, 256]
+        log_std_bounds = [-20.0, 2.0]
+
+        momentum = 0.01
+        act_net = nn.Sequential(
+            BatchRenorm1d(observation_size, momentum=momentum),
+            nn.Linear(observation_size, hidden_sizes[0], bias=False),
+            nn.ReLU(),
+            BatchRenorm1d(hidden_sizes[0], momentum=momentum),
+            nn.Linear(hidden_sizes[0], hidden_sizes[1], bias=False),
+            nn.ReLU(),
+            BatchRenorm1d(hidden_sizes[1], momentum=momentum),
+        )
+
+        mean_linear = nn.Linear(hidden_sizes[1], num_actions)
+        log_std_linear = nn.Linear(hidden_sizes[1], num_actions)
+
+        super().__init__(
+            act_net=act_net,
+            mean_linear=mean_linear,
+            log_std_linear=log_std_linear,
+            num_actions=num_actions,
+            log_std_bounds=log_std_bounds,
+        )
+
+
+class Actor(BaseActor):
+    # DiagGaussianActor
+    """torch.distributions implementation of an diagonal Gaussian policy."""
+
+    def __init__(
+        self,
+        observation_size: int,
+        num_actions: int,
+        config: CrossQConfig,
+    ):
+        hidden_sizes = config.hidden_size_actor
+        log_std_bounds = config.log_std_bounds
+
+        momentum = 0.01
+        act_net = nn.Sequential(
+            BatchRenorm1d(observation_size, momentum=momentum),
+            nn.Linear(observation_size, hidden_sizes[0], bias=False),
+            nn.ReLU(),
+            BatchRenorm1d(hidden_sizes[0], momentum=momentum),
+            nn.Linear(hidden_sizes[0], hidden_sizes[1], bias=False),
+            nn.ReLU(),
+            BatchRenorm1d(hidden_sizes[1], momentum=momentum),
+        )
+
+        mean_linear = nn.Linear(hidden_sizes[1], num_actions)
+        log_std_linear = nn.Linear(hidden_sizes[1], num_actions)
+
+        super().__init__(
+            act_net=act_net,
+            mean_linear=mean_linear,
+            log_std_linear=log_std_linear,
+            num_actions=num_actions,
+            log_std_bounds=log_std_bounds,
+        )
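
The forward() shared by all three actor classes lives on BaseActor; only its tail is visible at the top of the second hunk above. A sketch of the full pass, assuming SquashedNormal is the usual tanh-squashed Gaussian from util.common and that the log-std head is clamped into log_std_bounds (the exact clamping scheme is not visible in this diff):

import torch

from cares_reinforcement_learning.util.common import SquashedNormal


def actor_forward(actor, state: torch.Tensor):
    # Trunk produces features; two linear heads produce mean and log-std.
    x = actor.act_net(state)
    mean = actor.mean_linear(x)
    log_std = actor.log_std_linear(x)

    # Keep log-std inside the configured bounds, now defaulting to [-20, 2].
    # A hard clamp is shown as a stand-in; the repo may rescale differently.
    log_std_min, log_std_max = actor.log_std_bounds
    log_std = torch.clamp(log_std, log_std_min, log_std_max)

    dist = SquashedNormal(mean, log_std.exp())
    sample = dist.rsample()
    # These last lines appear verbatim at the top of the hunk above.
    log_pi = dist.log_prob(sample).sum(-1, keepdim=True)
    return sample, log_pi, dist.mean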
113 changes: 84 additions & 29 deletions cares_reinforcement_learning/networks/CrossQ/critic.py
@@ -2,52 +2,107 @@
 from batchrenorm import BatchRenorm1d
 from torch import nn
 
+from cares_reinforcement_learning.util.configurations import CrossQConfig
 
-class Critic(nn.Module):
+
+class BaseCritic(nn.Module):
+    def __init__(self, Q1: nn.Module, Q2: nn.Module):
+        super().__init__()
+
+        # Q1 architecture
+        # pylint: disable-next=invalid-name
+        self.Q1 = Q1
+
+        # Q2 architecture
+        # pylint: disable-next=invalid-name
+        self.Q2 = Q2
+
+    def forward(
+        self, state: torch.Tensor, action: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        obs_action = torch.cat([state, action], dim=1)
+        q1 = self.Q1(obs_action)
+        q2 = self.Q2(obs_action)
+        return q1, q2
+
+
+class DefaultCritic(BaseCritic):
     def __init__(
         self,
         observation_size: int,
         num_actions: int,
-        hidden_size: list[int] = None,
+        hidden_sizes: list[int] | None = None,
     ):
-        super().__init__()
-        if hidden_size is None:
-            hidden_size = [2048, 2048]
+        if hidden_sizes is None:
+            hidden_sizes = [2048, 2048]
 
-        self.hidden_size = hidden_size
-        self.input_size = observation_size + num_actions
+        input_size = observation_size + num_actions
 
         # Q1 architecture
         # pylint: disable-next=invalid-name
-        momentum = 0.1
-        self.Q1 = nn.Sequential(
-            BatchRenorm1d(self.input_size, momentum=momentum),
-            nn.Linear(self.input_size, self.hidden_size[0], bias=False),
+        momentum = 0.01
+        Q1 = nn.Sequential(
+            BatchRenorm1d(input_size, momentum=momentum),
+            nn.Linear(input_size, hidden_sizes[0], bias=False),
             nn.ReLU(),
-            BatchRenorm1d(self.hidden_size[0], momentum=momentum),
-            nn.Linear(self.hidden_size[0], self.hidden_size[1], bias=False),
+            BatchRenorm1d(hidden_sizes[0], momentum=momentum),
+            nn.Linear(hidden_sizes[0], hidden_sizes[1], bias=False),
             nn.ReLU(),
-            BatchRenorm1d(self.hidden_size[1], momentum=momentum),
-            nn.Linear(self.hidden_size[1], 1),
+            BatchRenorm1d(hidden_sizes[1], momentum=momentum),
+            nn.Linear(hidden_sizes[1], 1),
         )
 
         # Q2 architecture
         # pylint: disable-next=invalid-name
-        self.Q2 = nn.Sequential(
-            BatchRenorm1d(self.input_size, momentum=momentum),
-            nn.Linear(self.input_size, self.hidden_size[0], bias=False),
+        Q2 = nn.Sequential(
+            BatchRenorm1d(input_size, momentum=momentum),
+            nn.Linear(input_size, hidden_sizes[0], bias=False),
             nn.ReLU(),
-            BatchRenorm1d(self.hidden_size[0], momentum=momentum),
-            nn.Linear(self.hidden_size[0], self.hidden_size[1], bias=False),
+            BatchRenorm1d(hidden_sizes[0], momentum=momentum),
+            nn.Linear(hidden_sizes[0], hidden_sizes[1], bias=False),
            nn.ReLU(),
-            BatchRenorm1d(self.hidden_size[1], momentum=momentum),
-            nn.Linear(self.hidden_size[1], 1),
+            BatchRenorm1d(hidden_sizes[1], momentum=momentum),
+            nn.Linear(hidden_sizes[1], 1),
         )
 
-    def forward(
-        self, state: torch.Tensor, action: torch.Tensor
-    ) -> tuple[torch.Tensor, torch.Tensor]:
-        obs_action = torch.cat([state, action], dim=1)
-        q1 = self.Q1(obs_action)
-        q2 = self.Q2(obs_action)
-        return q1, q2
+        super().__init__(Q1=Q1, Q2=Q2)
+
+
+class Critic(BaseCritic):
+    def __init__(
+        self,
+        observation_size: int,
+        num_actions: int,
+        config: CrossQConfig,
+    ):
+        input_size = observation_size + num_actions
+        hidden_sizes = config.hidden_size_critic
+
+        # Q1 architecture
+        # pylint: disable-next=invalid-name
+        momentum = 0.01
+        Q1 = nn.Sequential(
+            BatchRenorm1d(input_size, momentum=momentum),
+            nn.Linear(input_size, hidden_sizes[0], bias=False),
+            nn.ReLU(),
+            BatchRenorm1d(hidden_sizes[0], momentum=momentum),
+            nn.Linear(hidden_sizes[0], hidden_sizes[1], bias=False),
+            nn.ReLU(),
+            BatchRenorm1d(hidden_sizes[1], momentum=momentum),
+            nn.Linear(hidden_sizes[1], 1),
+        )
+
+        # Q2 architecture
+        # pylint: disable-next=invalid-name
+        Q2 = nn.Sequential(
+            BatchRenorm1d(input_size, momentum=momentum),
+            nn.Linear(input_size, hidden_sizes[0], bias=False),
+            nn.ReLU(),
+            BatchRenorm1d(hidden_sizes[0], momentum=momentum),
+            nn.Linear(hidden_sizes[0], hidden_sizes[1], bias=False),
+            nn.ReLU(),
+            BatchRenorm1d(hidden_sizes[1], momentum=momentum),
+            nn.Linear(hidden_sizes[1], 1),
+        )
+
+        super().__init__(Q1=Q1, Q2=Q2)
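
Because BaseCritic only stores the two Q-networks and concatenates [state, action] in forward(), any nn.Module mapping that concatenation to a scalar can be dropped in; this is what leaves room for the plain-MLP variant the commit message says is not wired up yet. A hedged sketch of such a custom critic, with arbitrary sizes and no BatchRenorm1d:

import torch
from torch import nn

from cares_reinforcement_learning.networks.CrossQ.critic import BaseCritic

obs_size, n_actions = 17, 6  # arbitrary example sizes
input_size = obs_size + n_actions


def make_q() -> nn.Sequential:
    # Plain MLP head mapping cat([state, action]) to a scalar Q-value.
    return nn.Sequential(
        nn.Linear(input_size, 256),
        nn.ReLU(),
        nn.Linear(256, 256),
        nn.ReLU(),
        nn.Linear(256, 1),
    )


critic = BaseCritic(Q1=make_q(), Q2=make_q())
q1, q2 = critic(torch.zeros(1, obs_size), torch.zeros(1, n_actions))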
2 changes: 1 addition & 1 deletion cares_reinforcement_learning/util/configurations.py
@@ -405,7 +405,7 @@ class CrossQConfig(AlgorithmConfig):
     gamma: float = 0.99
     reward_scale: float = 1.0
 
-    log_std_bounds: list[float] = [-5, 2]
+    log_std_bounds: list[float] = [-20, 2]
 
     policy_update_freq: int = 3
 
9 changes: 2 additions & 7 deletions cares_reinforcement_learning/util/network_factory.py
@@ -273,13 +273,8 @@ def create_CrossQ(observation_size, action_num, config: acf.CrossQConfig):
     from cares_reinforcement_learning.algorithm.policy import CrossQ
     from cares_reinforcement_learning.networks.CrossQ import Actor, Critic
 
-    actor = Actor(
-        observation_size,
-        action_num,
-        hidden_size=config.hidden_size_actor,
-        log_std_bounds=config.log_std_bounds,
-    )
-    critic = Critic(observation_size, action_num, hidden_size=config.hidden_size_critic)
+    actor = Actor(observation_size, action_num, config=config)
+    critic = Critic(observation_size, action_num, config=config)
 
     device = hlp.get_device()
     agent = CrossQ(
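
With the factory now forwarding the whole config object, agent construction reduces to the following sketch; it assumes create_CrossQ is importable at module level, as its definition above suggests, and that CrossQConfig instantiates with defaults:

import cares_reinforcement_learning.util.configurations as acf
from cares_reinforcement_learning.util.network_factory import create_CrossQ

config = acf.CrossQConfig()  # log_std_bounds now defaults to [-20, 2]
agent = create_CrossQ(observation_size=17, action_num=6, config=config)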
