Skip to content

Commit

Permalink
Merge pull request #252 from torchmd/ani2x_dataloader
Browse files Browse the repository at this point in the history
Ani-2x dataloader
  • Loading branch information
RaulPPelaez authored Jan 18, 2024
2 parents 93d3d8b + f0f5904 commit 2c2b5f0
Show file tree
Hide file tree
Showing 4 changed files with 157 additions and 22 deletions.
58 changes: 58 additions & 0 deletions examples/TensorNet-ANI2X.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
activation: silu
aggr: add
atom_filter: -1
batch_size: 64
coord_files: null
cutoff_lower: 0.0
cutoff_upper: 5.0
dataset: ANI2X
dataset_root: ~/data
derivative: true
distance_influence: both
early_stopping_patience: 50
ema_alpha_neg_dy: 1.0
ema_alpha_y: 1.0
embed_files: null
embedding_dimension: 128
energy_files: null
equivariance_invariance_group: O(3)
y_weight: 1
force_files: null
neg_dy_weight: 100
gradient_clipping: 100.0
inference_batch_size: 64
load_model: null
log_dir: logs/
lr: 1.0e-3
lr_factor: 0.5
lr_min: 1.0e-07
lr_patience: 4
lr_warmup_steps: 1000
max_num_neighbors: 128
max_z: 128
model: tensornet
ngpus: -1
num_epochs: 500
num_layers: 2
num_nodes: 1
num_rbf: 32
num_workers: 4
output_model: Scalar
precision: 32
prior_model: null
rbf_type: expnorm
redirect: false
reduce_op: add
save_interval: 10
splits: null
seed: 1
standardize: false
test_interval: -1
test_size: null
train_size: 0.8
trainable_rbf: false
val_size: 0.1
weight_decay: 0.0
charge: false
spin: false
tensorboard_use: true
3 changes: 2 additions & 1 deletion torchmdnet/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# (See accompanying file README.md file or copy at http://opensource.org/licenses/MIT)

from .ace import Ace
from .ani import ANI1, ANI1CCX, ANI1X
from .ani import ANI1, ANI1CCX, ANI1X, ANI2X
from .comp6 import ANIMD, DrugBank, GDB07to09, GDB10to13, Tripeptides, S66X8, COMP6v1
from .custom import Custom
from .water import WaterBox
Expand All @@ -20,6 +20,7 @@
"ANI1",
"ANI1CCX",
"ANI1X",
"ANI2X",
"COMP6v1",
"Custom",
"DrugBank",
Expand Down
112 changes: 93 additions & 19 deletions torchmdnet/datasets/ani.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class ANIBase(Dataset):
- Smith, J. S., Zubatyuk, R., Nebgen, B., Lubbers, N., Barros, K., Roitberg, A. E., Isayev, O., & Tretiak, S. (2020). The ANI-1ccx and ANI-1x data sets, coupled-cluster and density functional theory properties for molecules. Scientific Data, 7, Article 134.
"""

HARTREE_TO_EV = 27.211386246 #::meta private:
HARTREE_TO_EV = 27.211386246 #::meta private:

@property
def raw_url(self):
Expand Down Expand Up @@ -95,7 +95,6 @@ def processed_file_names(self):
]

def filter_and_pre_transform(self, data):

if self.pre_filter is not None and not self.pre_filter(data):
return None

Expand Down Expand Up @@ -132,7 +131,10 @@ def process(self):
)
neg_dy_mm = (
np.memmap(
neg_dy_name + ".tmp", mode="w+", dtype=np.float32, shape=(num_all_atoms, 3)
neg_dy_name + ".tmp",
mode="w+",
dtype=np.float32,
shape=(num_all_atoms, 3),
)
if has_neg_dy
else open(neg_dy_name, "w")
Expand Down Expand Up @@ -206,12 +208,12 @@ def get(self, idx):
class ANI1(ANIBase):
__doc__ = ANIBase.__doc__
# Avoid sphinx from documenting this
ELEMENT_ENERGIES = {
_ELEMENT_ENERGIES = {
1: -0.500607632585,
6: -37.8302333826,
7: -54.5680045287,
8: -75.0362229210,
} #::meta private:
} #::meta private:

@property
def raw_url(self):
Expand All @@ -229,7 +231,6 @@ def download(self):
os.remove(archive)

def sample_iter(self, mol_ids=False):

atomic_numbers = {b"H": 1, b"C": 6, b"N": 7, b"O": 8}

for path in tqdm(self.raw_paths, desc="Files"):
Expand All @@ -249,7 +250,6 @@ def sample_iter(self, mol_ids=False):
assert all_pos.shape[2] == 3

for pos, y in zip(all_pos, all_y):

# Create a sample
args = dict(z=z, pos=pos, y=y.view(1, 1))
if mol_ids:
Expand All @@ -260,8 +260,7 @@ def sample_iter(self, mol_ids=False):
yield data

def get_atomref(self, max_z=100):
"""Atomic energy reference values for the :py:mod:`torchmdnet.priors.Atomref` prior.
"""
"""Atomic energy reference values for the :py:mod:`torchmdnet.priors.Atomref` prior."""
refs = pt.zeros(max_z)
refs[1] = -0.500607632585 * self.HARTREE_TO_EV # H
refs[6] = -37.8302333826 * self.HARTREE_TO_EV # C
Expand All @@ -277,7 +276,6 @@ def process(self):


class ANI1XBase(ANIBase):

@property
def raw_url(self):
return "https://figshare.com/ndownloader/files/18112775"
Expand All @@ -292,8 +290,7 @@ def download(self):
os.rename(file, self.raw_paths[0])

def get_atomref(self, max_z=100):
"""Atomic energy reference values for the :py:mod:`torchmdnet.priors.Atomref` prior.
"""
"""Atomic energy reference values for the :py:mod:`torchmdnet.priors.Atomref` prior."""
warnings.warn("Atomic references from the ANI-1 dataset are used!")

refs = pt.zeros(max_z)
Expand All @@ -307,7 +304,7 @@ def get_atomref(self, max_z=100):

class ANI1X(ANI1XBase):
__doc__ = ANIBase.__doc__
ELEMENT_ENERGIES = {
_ELEMENT_ENERGIES = {
1: -0.500607632585,
6: -37.8302333826,
7: -54.5680045287,
Expand All @@ -317,13 +314,12 @@ class ANI1X(ANI1XBase):
:meta private:
"""
def sample_iter(self, mol_ids=False):

def sample_iter(self, mol_ids=False):
assert len(self.raw_paths) == 1

with h5py.File(self.raw_paths[0]) as h5:
for mol_id, mol in tqdm(h5.items(), desc="Molecules"):

z = pt.tensor(mol["atomic_numbers"][:], dtype=pt.long)
all_pos = pt.tensor(mol["coordinates"][:], dtype=pt.float32)
all_y = pt.tensor(
Expand All @@ -342,7 +338,6 @@ def sample_iter(self, mol_ids=False):
assert all_neg_dy.shape[2] == 3

for pos, y, neg_dy in zip(all_pos, all_y, all_neg_dy):

if y.isnan() or neg_dy.isnan().any():
continue

Expand All @@ -368,13 +363,12 @@ def process(self):

class ANI1CCX(ANI1XBase):
__doc__ = ANIBase.__doc__
def sample_iter(self, mol_ids=False):

def sample_iter(self, mol_ids=False):
assert len(self.raw_paths) == 1

with h5py.File(self.raw_paths[0]) as h5:
for mol_id, mol in tqdm(h5.items(), desc="Molecules"):

z = pt.tensor(mol["atomic_numbers"][:], dtype=pt.long)
all_pos = pt.tensor(mol["coordinates"][:], dtype=pt.float32)
all_y = pt.tensor(
Expand All @@ -386,7 +380,6 @@ def sample_iter(self, mol_ids=False):
assert all_pos.shape[2] == 3

for pos, y in zip(all_pos, all_y):

if y.isnan():
continue

Expand All @@ -408,3 +401,84 @@ def download(self):
# TODO remove when fixed
def process(self):
super().process()


class ANI2X(ANIBase):
__doc__ = ANIBase.__doc__

# Taken from https://github.com/isayev/ASE_ANI/blob/master/ani_models/ani-2x_8x/sae_linfit.dat
_ELEMENT_ENERGIES = {
1: -0.5978583943827134, # H
6: -38.08933878049795, # C
7: -54.711968298621066, # N
8: -75.19106774742086, # O
9: -99.80348506781634, # F
16: -398.1577125334925, # S
17: -460.1681939421027, # Cl
}

@property
def raw_url(self):
return "https://zenodo.org/records/10108942/files/ANI-2x-wB97X-631Gd.tar.gz"

@property
def raw_file_names(self):
return [os.path.join("final_h5", "ANI-2x-wB97X-631Gd.h5")]

def download(self):
archive = download_url(self.raw_url, self.raw_dir)
extract_tar(archive, self.raw_dir)
os.remove(archive)

def sample_iter(self, mol_ids=False):
"""
In [15]: list(molecules)
Out[15]:
[('coordinates', <HDF5 dataset "coordinates": shape (5706, 2, 3), type "<f4">),
('energies', <HDF5 dataset "energies": shape (5706,), type "<f8">),
('forces', <HDF5 dataset "forces": shape (5706, 2, 3), type "<f8">),
('species', <HDF5 dataset "species": shape (5706, 2), type "<i8">)]
"""
assert len(self.raw_paths) == 1
with h5py.File(self.raw_paths[0]) as h5data:
for key, data in tqdm(h5data.items(), desc="Molecule Group", leave=False):
all_z = pt.tensor(data["species"][:], dtype=pt.long)
all_pos = pt.tensor(data["coordinates"][:], dtype=pt.float32)
all_y = pt.tensor(
data["energies"][:] * self.HARTREE_TO_EV, dtype=pt.float64
)
all_neg_dy = pt.tensor(
data["forces"][:] * self.HARTREE_TO_EV, dtype=pt.float32
)
n_mols = all_pos.shape[0]
n_atoms = all_pos.shape[1]

assert all_y.shape[0] == n_mols
assert all_z.shape == (n_mols, n_atoms)
assert all_pos.shape == (n_mols, n_atoms, 3)
assert all_neg_dy.shape == (n_mols, n_atoms, 3)

for i, (pos, y, z, neg_dy) in enumerate(
zip(all_pos, all_y, all_z, all_neg_dy)
):
# Create a sample
args = dict(z=z, pos=pos, y=y.view(1, 1), neg_dy=neg_dy)
if mol_ids:
args["mol_id"] = f"{key}_{i}"
data = Data(**args)

if data := self.filter_and_pre_transform(data):
yield data

def get_atomref(self, max_z=100):
"""Atomic energy reference values for the :py:mod:`torchmdnet.priors.Atomref` prior."""
refs = pt.zeros(max_z)
for key, val in self._ELEMENT_ENERGIES.items():
refs[key] = val * self.HARTREE_TO_EV

return refs.view(-1, 1)

# Circumvent https://github.com/pyg-team/pytorch_geometric/issues/4567
# TODO remove when fixed
def process(self):
super().process()
6 changes: 4 additions & 2 deletions torchmdnet/scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from torchmdnet.utils import LoadFromFile, LoadFromCheckpoint, save_argparse, number
from lightning_utilities.core.rank_zero import rank_zero_warn


def get_argparse():
# fmt: off
parser = argparse.ArgumentParser(description='Training')
Expand Down Expand Up @@ -119,8 +120,8 @@ def get_argparse():
# fmt: on
return parser

def get_args():

def get_args():
parser = get_argparse()
args = parser.parse_args()
if args.redirect:
Expand All @@ -133,6 +134,7 @@ def get_args():
if args.inference_batch_size is None:
args.inference_batch_size = args.batch_size

os.makedirs(os.path.abspath(args.log_dir), exist_ok=True)
save_argparse(args, os.path.join(args.log_dir, "input.yaml"), exclude=["conf"])

return args
Expand Down Expand Up @@ -179,7 +181,7 @@ def main():

if args.tensorboard_use:
tb_logger = TensorBoardLogger(
args.log_dir, name="tensorbord", version="", default_hp_metric=False
args.log_dir, name="tensorboard", version="", default_hp_metric=False
)
_logger.append(tb_logger)
if args.test_interval > 0:
Expand Down

0 comments on commit 2c2b5f0

Please sign in to comment.