forked from facebookresearch/BenchMARL
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmlp.py
157 lines (134 loc) · 5.69 KB
/
mlp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
from __future__ import annotations
from dataclasses import dataclass, MISSING
from typing import Optional, Sequence, Type
import torch
from tensordict import TensorDictBase
from torch import nn
from torchrl.modules import MLP, MultiAgentMLP
from benchmarl.models.common import Model, ModelConfig
class Mlp(Model):
    """Multi layer perceptron model.

    Args:
        num_cells (int or Sequence[int], optional): number of cells of every layer in between the input and output. If
            an integer is provided, every layer will have the same number of cells. If an iterable is provided,
            the linear layers out_features will match the content of num_cells.
        layer_class (Type[nn.Module]): class to be used for the linear layers;
        activation_class (Type[nn.Module]): activation class to be used.
        activation_kwargs (dict, optional): kwargs to be used with the activation class;
        norm_class (Type, optional): normalization class, if any.
        norm_kwargs (dict, optional): kwargs to be used with the normalization layers;

    The arguments above are not consumed by this class directly. ``__init__``
    pops the :class:`Model` arguments (``input_spec``, ``output_spec``,
    ``agent_group``, ...) from ``kwargs``. Everything that remains is forwarded
    verbatim to :class:`torchrl.modules.MultiAgentMLP` when the input carries the
    agent dimension, or otherwise to each :class:`torchrl.modules.MLP`.
    """

    def __init__(
        self,
        **kwargs,
    ):
        super().__init__(
            input_spec=kwargs.pop("input_spec"),
            output_spec=kwargs.pop("output_spec"),
            agent_group=kwargs.pop("agent_group"),
            input_has_agent_dim=kwargs.pop("input_has_agent_dim"),
            n_agents=kwargs.pop("n_agents"),
            centralised=kwargs.pop("centralised"),
            share_params=kwargs.pop("share_params"),
            device=kwargs.pop("device"),
            action_spec=kwargs.pop("action_spec"),
            model_index=kwargs.pop("model_index"),
            is_critic=kwargs.pop("is_critic"),
        )

        # ``_forward`` concatenates the inputs along their last dimension, so the
        # network input width is the sum of the input leaves' last dimensions.
        self.input_features = sum(
            spec.shape[-1] for spec in self.input_spec.values(True, True)
        )
        self.output_features = self.output_leaf_spec.shape[-1]

        if self.input_has_agent_dim:
            # A single torchrl MultiAgentMLP handles the agent dimension itself;
            # centralisation and parameter sharing are delegated to it.
            self.mlp = MultiAgentMLP(
                n_agent_inputs=self.input_features,
                n_agent_outputs=self.output_features,
                n_agents=self.n_agents,
                centralised=self.centralised,
                share_params=self.share_params,
                device=self.device,
                **kwargs,
            )
        else:
            # Without an agent dimension in the input, build one MLP per agent,
            # or a single MLP when parameters are shared.
            n_nets = 1 if self.share_params else self.n_agents
            self.mlp = nn.ModuleList(
                [
                    MLP(
                        in_features=self.input_features,
                        out_features=self.output_features,
                        device=self.device,
                        **kwargs,
                    )
                    for _ in range(n_nets)
                ]
            )

    def _perform_checks(self):
        """Validate that the input and output specs are compatible with an MLP.

        Raises:
            ValueError: if an input leaf has the wrong rank, if the input leaves
                disagree on their shape up to the last dimension, if the input
                has the agent dimension but no input leaf is given, or if an
                agent dimension does not match ``n_agents``.
        """
        super()._perform_checks()

        # Each input leaf must be a flat feature vector, optionally preceded by
        # the agent dimension.
        expected_ndim = 2 if self.input_has_agent_dim else 1
        input_shape = None
        for input_key, input_spec in self.input_spec.items(True, True):
            if len(input_spec.shape) != expected_ndim:
                raise ValueError(
                    f"MLP input value {input_key} from {self.input_spec} has an invalid shape, maybe you need a CNN?"
                )
            if input_shape is None:
                input_shape = input_spec.shape[:-1]
            elif input_spec.shape[:-1] != input_shape:
                raise ValueError(
                    f"MLP inputs should all have the same shape up to the last dimension, got {self.input_spec}"
                )

        if self.input_has_agent_dim:
            # Without any input leaf there is no agent dimension to check.
            # Previously this path failed with an opaque TypeError on ``None[-1]``.
            if input_shape is None:
                raise ValueError(
                    f"MLP requires at least one input value, got {self.input_spec}"
                )
            if input_shape[-1] != self.n_agents:
                raise ValueError(
                    "If the MLP input has the agent dimension,"
                    f" the second to last spec dimension should be the number of agents, got {self.input_spec}"
                )
        if (
            self.output_has_agent_dim
            and self.output_leaf_spec.shape[-2] != self.n_agents
        ):
            raise ValueError(
                "If the MLP output has the agent dimension,"
                " the second to last spec dimension should be the number of agents"
            )

    def _forward(self, tensordict: TensorDictBase) -> TensorDictBase:
        """Run the MLP on the concatenated ``in_keys`` and write the result to ``out_key``.

        Args:
            tensordict: container holding every entry listed in ``self.in_keys``.

        Returns:
            The same ``tensordict``, with ``self.out_key`` set to the network output.
        """
        # Gather the in_keys and concatenate them along the feature dimension.
        net_input = torch.cat(
            [tensordict.get(in_key) for in_key in self.in_keys], dim=-1
        )

        if self.input_has_agent_dim:
            # Has multi-agent input dimension.
            # Call the module rather than ``.forward`` so that registered hooks run.
            res = self.mlp(net_input)
            if not self.output_has_agent_dim:
                # If we are here the module is centralised and parameter shared.
                # Thus the multi-agent dimension has been expanded,
                # We remove it without loss of data
                res = res[..., 0, :]
        elif self.share_params:
            # No multi-agent input dimension; one shared network.
            res = self.mlp[0](net_input)
        else:
            # No multi-agent input dimension; one network per agent. Stacking
            # their outputs on dim -2 creates the agent dimension.
            res = torch.stack([net(net_input) for net in self.mlp], dim=-2)

        tensordict.set(self.out_key, res)
        return tensordict
@dataclass
class MlpConfig(ModelConfig):
    """Dataclass config for a :class:`~benchmarl.models.Mlp`.

    The fields mirror the keyword arguments documented on :class:`Mlp`.
    Fields whose default is ``MISSING`` have no default value, so they must be
    passed to the constructor. The ``Optional`` fields default to ``None``.
    """

    # Hidden layer sizes; the Mlp docstring also allows a single int.
    num_cells: Sequence[int] = MISSING
    layer_class: Type[nn.Module] = MISSING
    activation_class: Type[nn.Module] = MISSING

    activation_kwargs: Optional[dict] = None

    # The default is None, so the annotation must be Optional (it was
    # previously annotated as a plain Type[nn.Module]).
    norm_class: Optional[Type[nn.Module]] = None
    norm_kwargs: Optional[dict] = None

    @staticmethod
    def associated_class():
        """Return the model class this config builds (:class:`Mlp`)."""
        return Mlp