diff --git a/examples/TensorNet-ANI2X.yaml b/examples/TensorNet-ANI2X.yaml new file mode 100644 index 000000000..a82f1d446 --- /dev/null +++ b/examples/TensorNet-ANI2X.yaml @@ -0,0 +1,58 @@ +activation: silu +aggr: add +atom_filter: -1 +batch_size: 64 +coord_files: null +cutoff_lower: 0.0 +cutoff_upper: 5.0 +dataset: ANI2X +dataset_root: ~/data +derivative: true +distance_influence: both +early_stopping_patience: 50 +ema_alpha_neg_dy: 1.0 +ema_alpha_y: 1.0 +embed_files: null +embedding_dimension: 128 +energy_files: null +equivariance_invariance_group: O(3) +y_weight: 1 +force_files: null +neg_dy_weight: 100 +gradient_clipping: 100.0 +inference_batch_size: 64 +load_model: null +log_dir: logs/ +lr: 1.0e-3 +lr_factor: 0.5 +lr_min: 1.0e-07 +lr_patience: 4 +lr_warmup_steps: 1000 +max_num_neighbors: 128 +max_z: 128 +model: tensornet +ngpus: -1 +num_epochs: 500 +num_layers: 2 +num_nodes: 1 +num_rbf: 32 +num_workers: 4 +output_model: Scalar +precision: 32 +prior_model: null +rbf_type: expnorm +redirect: false +reduce_op: add +save_interval: 10 +splits: null +seed: 1 +standardize: false +test_interval: -1 +test_size: null +train_size: 0.8 +trainable_rbf: false +val_size: 0.1 +weight_decay: 0.0 +charge: false +spin: false +tensorboard_use: true diff --git a/torchmdnet/datasets/__init__.py b/torchmdnet/datasets/__init__.py index 50c1462a1..42dc59dec 100644 --- a/torchmdnet/datasets/__init__.py +++ b/torchmdnet/datasets/__init__.py @@ -3,7 +3,7 @@ # (See accompanying file README.md file or copy at http://opensource.org/licenses/MIT) from .ace import Ace -from .ani import ANI1, ANI1CCX, ANI1X +from .ani import ANI1, ANI1CCX, ANI1X, ANI2X from .comp6 import ANIMD, DrugBank, GDB07to09, GDB10to13, Tripeptides, S66X8, COMP6v1 from .custom import Custom from .water import WaterBox @@ -20,6 +20,7 @@ "ANI1", "ANI1CCX", "ANI1X", + "ANI2X", "COMP6v1", "Custom", "DrugBank", diff --git a/torchmdnet/datasets/ani.py b/torchmdnet/datasets/ani.py index 25e3df7bf..4149f0bf4 100644 --- a/torchmdnet/datasets/ani.py +++ b/torchmdnet/datasets/ani.py @@ -34,7 +34,7 @@ class ANIBase(Dataset): - Smith, J. S., Zubatyuk, R., Nebgen, B., Lubbers, N., Barros, K., Roitberg, A. E., Isayev, O., & Tretiak, S. (2020). The ANI-1ccx and ANI-1x data sets, coupled-cluster and density functional theory properties for molecules. Scientific Data, 7, Article 134. """ - HARTREE_TO_EV = 27.211386246 #::meta private: + HARTREE_TO_EV = 27.211386246 #::meta private: @property def raw_url(self): @@ -95,7 +95,6 @@ def processed_file_names(self): ] def filter_and_pre_transform(self, data): - if self.pre_filter is not None and not self.pre_filter(data): return None @@ -132,7 +131,10 @@ def process(self): ) neg_dy_mm = ( np.memmap( - neg_dy_name + ".tmp", mode="w+", dtype=np.float32, shape=(num_all_atoms, 3) + neg_dy_name + ".tmp", + mode="w+", + dtype=np.float32, + shape=(num_all_atoms, 3), ) if has_neg_dy else open(neg_dy_name, "w") @@ -206,12 +208,12 @@ def get(self, idx): class ANI1(ANIBase): __doc__ = ANIBase.__doc__ # Avoid sphinx from documenting this - ELEMENT_ENERGIES = { + _ELEMENT_ENERGIES = { 1: -0.500607632585, 6: -37.8302333826, 7: -54.5680045287, 8: -75.0362229210, - } #::meta private: + } #::meta private: @property def raw_url(self): @@ -229,7 +231,6 @@ def download(self): os.remove(archive) def sample_iter(self, mol_ids=False): - atomic_numbers = {b"H": 1, b"C": 6, b"N": 7, b"O": 8} for path in tqdm(self.raw_paths, desc="Files"): @@ -249,7 +250,6 @@ def sample_iter(self, mol_ids=False): assert all_pos.shape[2] == 3 for pos, y in zip(all_pos, all_y): - # Create a sample args = dict(z=z, pos=pos, y=y.view(1, 1)) if mol_ids: @@ -260,8 +260,7 @@ def sample_iter(self, mol_ids=False): yield data def get_atomref(self, max_z=100): - """Atomic energy reference values for the :py:mod:`torchmdnet.priors.Atomref` prior. - """ + """Atomic energy reference values for the :py:mod:`torchmdnet.priors.Atomref` prior.""" refs = pt.zeros(max_z) refs[1] = -0.500607632585 * self.HARTREE_TO_EV # H refs[6] = -37.8302333826 * self.HARTREE_TO_EV # C @@ -277,7 +276,6 @@ def process(self): class ANI1XBase(ANIBase): - @property def raw_url(self): return "https://figshare.com/ndownloader/files/18112775" @@ -292,8 +290,7 @@ def download(self): os.rename(file, self.raw_paths[0]) def get_atomref(self, max_z=100): - """Atomic energy reference values for the :py:mod:`torchmdnet.priors.Atomref` prior. - """ + """Atomic energy reference values for the :py:mod:`torchmdnet.priors.Atomref` prior.""" warnings.warn("Atomic references from the ANI-1 dataset are used!") refs = pt.zeros(max_z) @@ -307,7 +304,7 @@ def get_atomref(self, max_z=100): class ANI1X(ANI1XBase): __doc__ = ANIBase.__doc__ - ELEMENT_ENERGIES = { + _ELEMENT_ENERGIES = { 1: -0.500607632585, 6: -37.8302333826, 7: -54.5680045287, @@ -317,13 +314,12 @@ class ANI1X(ANI1XBase): :meta private: """ - def sample_iter(self, mol_ids=False): + def sample_iter(self, mol_ids=False): assert len(self.raw_paths) == 1 with h5py.File(self.raw_paths[0]) as h5: for mol_id, mol in tqdm(h5.items(), desc="Molecules"): - z = pt.tensor(mol["atomic_numbers"][:], dtype=pt.long) all_pos = pt.tensor(mol["coordinates"][:], dtype=pt.float32) all_y = pt.tensor( @@ -342,7 +338,6 @@ def sample_iter(self, mol_ids=False): assert all_neg_dy.shape[2] == 3 for pos, y, neg_dy in zip(all_pos, all_y, all_neg_dy): - if y.isnan() or neg_dy.isnan().any(): continue @@ -368,13 +363,12 @@ def process(self): class ANI1CCX(ANI1XBase): __doc__ = ANIBase.__doc__ - def sample_iter(self, mol_ids=False): + def sample_iter(self, mol_ids=False): assert len(self.raw_paths) == 1 with h5py.File(self.raw_paths[0]) as h5: for mol_id, mol in tqdm(h5.items(), desc="Molecules"): - z = pt.tensor(mol["atomic_numbers"][:], dtype=pt.long) all_pos = pt.tensor(mol["coordinates"][:], dtype=pt.float32) all_y = pt.tensor( @@ -386,7 +380,6 @@ def sample_iter(self, mol_ids=False): assert all_pos.shape[2] == 3 for pos, y in zip(all_pos, all_y): - if y.isnan(): continue @@ -408,3 +401,84 @@ def download(self): # TODO remove when fixed def process(self): super().process() + + +class ANI2X(ANIBase): + __doc__ = ANIBase.__doc__ + + # Taken from https://github.com/isayev/ASE_ANI/blob/master/ani_models/ani-2x_8x/sae_linfit.dat + _ELEMENT_ENERGIES = { + 1: -0.5978583943827134, # H + 6: -38.08933878049795, # C + 7: -54.711968298621066, # N + 8: -75.19106774742086, # O + 9: -99.80348506781634, # F + 16: -398.1577125334925, # S + 17: -460.1681939421027, # Cl + } + + @property + def raw_url(self): + return "https://zenodo.org/records/10108942/files/ANI-2x-wB97X-631Gd.tar.gz" + + @property + def raw_file_names(self): + return [os.path.join("final_h5", "ANI-2x-wB97X-631Gd.h5")] + + def download(self): + archive = download_url(self.raw_url, self.raw_dir) + extract_tar(archive, self.raw_dir) + os.remove(archive) + + def sample_iter(self, mol_ids=False): + """ + In [15]: list(molecules) + Out[15]: + [('coordinates', ), + ('energies', ), + ('forces', ), + ('species', )] + """ + assert len(self.raw_paths) == 1 + with h5py.File(self.raw_paths[0]) as h5data: + for key, data in tqdm(h5data.items(), desc="Molecule Group", leave=False): + all_z = pt.tensor(data["species"][:], dtype=pt.long) + all_pos = pt.tensor(data["coordinates"][:], dtype=pt.float32) + all_y = pt.tensor( + data["energies"][:] * self.HARTREE_TO_EV, dtype=pt.float64 + ) + all_neg_dy = pt.tensor( + data["forces"][:] * self.HARTREE_TO_EV, dtype=pt.float32 + ) + n_mols = all_pos.shape[0] + n_atoms = all_pos.shape[1] + + assert all_y.shape[0] == n_mols + assert all_z.shape == (n_mols, n_atoms) + assert all_pos.shape == (n_mols, n_atoms, 3) + assert all_neg_dy.shape == (n_mols, n_atoms, 3) + + for i, (pos, y, z, neg_dy) in enumerate( + zip(all_pos, all_y, all_z, all_neg_dy) + ): + # Create a sample + args = dict(z=z, pos=pos, y=y.view(1, 1), neg_dy=neg_dy) + if mol_ids: + args["mol_id"] = f"{key}_{i}" + data = Data(**args) + + if data := self.filter_and_pre_transform(data): + yield data + + def get_atomref(self, max_z=100): + """Atomic energy reference values for the :py:mod:`torchmdnet.priors.Atomref` prior.""" + refs = pt.zeros(max_z) + for key, val in self._ELEMENT_ENERGIES.items(): + refs[key] = val * self.HARTREE_TO_EV + + return refs.view(-1, 1) + + # Circumvent https://github.com/pyg-team/pytorch_geometric/issues/4567 + # TODO remove when fixed + def process(self): + super().process() diff --git a/torchmdnet/scripts/train.py b/torchmdnet/scripts/train.py index 04e22758d..519a3f090 100644 --- a/torchmdnet/scripts/train.py +++ b/torchmdnet/scripts/train.py @@ -24,6 +24,7 @@ from torchmdnet.utils import LoadFromFile, LoadFromCheckpoint, save_argparse, number from lightning_utilities.core.rank_zero import rank_zero_warn + def get_argparse(): # fmt: off parser = argparse.ArgumentParser(description='Training') @@ -119,8 +120,8 @@ def get_argparse(): # fmt: on return parser -def get_args(): +def get_args(): parser = get_argparse() args = parser.parse_args() if args.redirect: @@ -133,6 +134,7 @@ def get_args(): if args.inference_batch_size is None: args.inference_batch_size = args.batch_size + os.makedirs(os.path.abspath(args.log_dir), exist_ok=True) save_argparse(args, os.path.join(args.log_dir, "input.yaml"), exclude=["conf"]) return args @@ -179,7 +181,7 @@ def main(): if args.tensorboard_use: tb_logger = TensorBoardLogger( - args.log_dir, name="tensorbord", version="", default_hp_metric=False + args.log_dir, name="tensorboard", version="", default_hp_metric=False ) _logger.append(tb_logger) if args.test_interval > 0: