Skip to content

Commit

Permalink
Use switching function for Coulomb prior (#287)
Browse files Browse the repository at this point in the history
* Use switching function for Coulomb prior

* Updated documentation
  • Loading branch information
peastman authored Feb 20, 2024
1 parent 5a4fca7 commit 166b7db
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 17 deletions.
11 changes: 6 additions & 5 deletions docs/source/priors.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,12 @@ It is possible to configure more than one prior in this way:

.. code:: yaml
prior_model:
Atomref: {} # No additional arguments
Coulomb:
alpha: 1
max_num_neighbors: 10
prior_model:
Atomref: {} # No additional arguments
Coulomb:
lower_switch_distance: 4
upper_switch_distance: 8
max_num_neighbors: 128
Expand Down
12 changes: 9 additions & 3 deletions tests/test_priors.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,19 +88,25 @@ def test_coulomb(dtype):
types = torch.tensor([0, 1, 2, 1], dtype=torch.long) # Atom types
distance_scale = 1e-9 # Convert nm to meters
energy_scale = 1000.0/6.02214076e23 # Convert kJ/mol to Joules
alpha = 1.8
lower_switch_distance = 0.9
upper_switch_distance = 1.3

# Use the Coulomb class to compute the energy.

coulomb = Coulomb(alpha, 5, distance_scale=distance_scale, energy_scale=energy_scale)
coulomb = Coulomb(lower_switch_distance, upper_switch_distance, 5, distance_scale=distance_scale, energy_scale=energy_scale)
energy = coulomb.post_reduce(torch.zeros((1,)), types, pos, torch.zeros_like(types), extra_args={'partial_charges':charge})[0]

# Compare to the expected value.

def compute_interaction(pos1, pos2, z1, z2):
delta = pos1-pos2
r = torch.sqrt(torch.dot(delta, delta))
return torch.erf(alpha*r)*138.935*z1*z2/r
if r < lower_switch_distance:
return 0
energy = 138.935*z1*z2/r
if r < upper_switch_distance:
energy *= 0.5-0.5*torch.cos(torch.pi*(r-lower_switch_distance)/(upper_switch_distance-lower_switch_distance))
return energy

expected = 0
for i in range(len(pos)):
Expand Down
22 changes: 14 additions & 8 deletions torchmdnet/priors/coulomb.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,15 @@
from typing import Optional, Dict

class Coulomb(BasePrior):
"""This class implements a Coulomb potential, scaled by :math:`\\textrm{erf}(\\textrm{alpha}*r)` to reduce its
"""This class implements a Coulomb potential, scaled by a cosine switching function to reduce its
effect at short distances.
Parameters
----------
alpha : float
Scaling factor for the error function.
lower_switch_distance : float
distance below which the interaction strength is zero.
upper_switch_distance : float
distance above which the interaction has full strength
max_num_neighbors : int
Maximum number of neighbors per atom allowed.
distance_scale : float, optional
Expand All @@ -31,20 +33,22 @@ class Coulomb(BasePrior):
The Dataset used with this class must include a `partial_charges` field for each sample, and provide
`distance_scale` and `energy_scale` attributes if they are not explicitly passed as arguments.
"""
def __init__(self, alpha, max_num_neighbors, distance_scale=None, energy_scale=None, box_vecs=None, dataset=None):
def __init__(self, lower_switch_distance, upper_switch_distance, max_num_neighbors, distance_scale=None, energy_scale=None, box_vecs=None, dataset=None):
super(Coulomb, self).__init__()
if distance_scale is None:
distance_scale = dataset.distance_scale
if energy_scale is None:
energy_scale = dataset.energy_scale
self.distance = OptimizedDistance(0, torch.inf, max_num_pairs=-max_num_neighbors)
self.alpha = alpha
self.lower_switch_distance = lower_switch_distance
self.upper_switch_distance = upper_switch_distance
self.max_num_neighbors = max_num_neighbors
self.distance_scale = float(distance_scale)
self.energy_scale = float(energy_scale)
self.initial_box = box_vecs
def get_init_args(self):
return {'alpha': self.alpha,
return {'lower_switch_distance': self.lower_switch_distance,
'upper_switch_distance': self.upper_switch_distance,
'max_num_neighbors': self.max_num_neighbors,
'distance_scale': self.distance_scale,
'energy_scale': self.energy_scale,
Expand Down Expand Up @@ -78,14 +82,16 @@ def post_reduce(self, y, z, pos, batch, box: Optional[torch.Tensor] = None, extr
"""
# Convert to nm and calculate distance.
x = 1e9*self.distance_scale*pos
alpha = self.alpha/(1e9*self.distance_scale)
box = box if box is not None else self.initial_box
edge_index, distance, _ = self.distance(x, batch, box=box)

# Compute the energy, converting to the dataset's units. Multiply by 0.5 because every atom pair
# appears twice.
q = extra_args['partial_charges'][edge_index]
energy = torch.erf(alpha*distance)*q[0]*q[1]/distance
lower = torch.tensor(self.lower_switch_distance)
upper = torch.tensor(self.upper_switch_distance)
phase = (torch.max(lower, torch.min(upper, distance))-lower)/(upper-lower)
energy = (0.5-0.5*torch.cos(torch.pi*phase))*q[0]*q[1]/distance
energy = 0.5*(2.30707e-28/self.energy_scale/self.distance_scale)*scatter(energy, batch[edge_index[0]], dim=0, reduce="sum")
energy = energy.reshape(y.shape)
return y + energy
2 changes: 1 addition & 1 deletion torchmdnet/scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def get_argparse():
# model architecture
parser.add_argument('--model', type=str, default='graph-network', choices=models.__all_models__, help='Which model to train')
parser.add_argument('--output-model', type=str, default='Scalar', choices=output_modules.__all__, help='The type of output model')
parser.add_argument('--prior-model', type=str, default=None, help='Which prior model to use. It can be a string, a dict if you want to add arguments for it or a dicts to add more than one prior. e.g. {"Atomref": {"max_z":100}, "Coulomb":{"max_num_neighs"=100, "alpha"=1}', action="extend", nargs="*")
parser.add_argument('--prior-model', type=str, default=None, help='Which prior model to use. It can be a string, a dict if you want to add arguments for it or a dicts to add more than one prior. e.g. {"Atomref": {"max_z":100}, "Coulomb":{"max_num_neighs"=100, "lower_switch_distance"=4, "upper_switch_distance"=8}', action="extend", nargs="*")

# architectural args
parser.add_argument('--charge', type=bool, default=False, help='Model needs a total charge. Set this to True if your dataset contains charges and you want them passed down to the model.')
Expand Down

0 comments on commit 166b7db

Please sign in to comment.