[WIP] Introduce additional_labels #306

Open
wants to merge 83 commits into
base: main
83 commits
43d8f9f
Update tensornet.py
guillemsimeon Mar 5, 2024
9ccc2c0
Update model.py
guillemsimeon Mar 5, 2024
25854fb
Update model.py
guillemsimeon Mar 5, 2024
cc6de7a
move to extra_fields implementation
AntonioMirarchi Mar 6, 2024
9478a92
remove charge and spin, this will go to extra_fields
AntonioMirarchi Mar 6, 2024
c7014e7
force hdf5 dataset to common-unique data structure
AntonioMirarchi Mar 6, 2024
1f4928c
force ace dataset to common-unique data structure
AntonioMirarchi Mar 6, 2024
7f238dc
remove charge and spin flag from train.py
AntonioMirarchi Mar 6, 2024
3250fc4
add extra_fields to argparse
AntonioMirarchi Mar 6, 2024
736af54
memdataset to common-unique data-structure
AntonioMirarchi Mar 6, 2024
d495044
qm9q to common-unique data structure
AntonioMirarchi Mar 6, 2024
014a06c
ET to extra_fields
AntonioMirarchi Mar 6, 2024
b081a09
add extra_fields documentation to the ET
AntonioMirarchi Mar 6, 2024
2910694
transformer to extra_fields
AntonioMirarchi Mar 6, 2024
c0489af
small fix in ET documentation
AntonioMirarchi Mar 6, 2024
cb2229a
graph-network to extra_fields
AntonioMirarchi Mar 6, 2024
2fe05bc
remove optional tensor for extra args, it's needed by default
AntonioMirarchi Mar 6, 2024
3150672
remove extra_args from model forward, extra_fields_args it's only needed
AntonioMirarchi Mar 6, 2024
6055357
add self.extra_fields to architectures
AntonioMirarchi Mar 6, 2024
e3ebbbe
remove all 'q' specific function, to move to more general extra_fields
AntonioMirarchi Mar 6, 2024
ea3acd6
change variable name t additional_labels and allow to be only a dict
AntonioMirarchi Mar 12, 2024
c34fc52
remove architectural redundancy
AntonioMirarchi Mar 12, 2024
32f2456
move to additional_labels verion
AntonioMirarchi Mar 12, 2024
ef4cfbe
use extra_args
AntonioMirarchi Mar 12, 2024
ebb08c2
tnsnet v2 with tensornetQ class as additional method
AntonioMirarchi Mar 12, 2024
91c7cde
update warning message
AntonioMirarchi Mar 12, 2024
fbfafaf
force labels to be atom_wise
AntonioMirarchi Mar 12, 2024
088dd5f
remove unused arg
AntonioMirarchi Mar 12, 2024
9479ae7
fix extra_args_nnp generation
AntonioMirarchi Mar 12, 2024
cd8ab2d
remove old code residue
AntonioMirarchi Mar 12, 2024
f48e408
use correct name in forward for extra_args
AntonioMirarchi Mar 12, 2024
ef64fc4
fix arg name
AntonioMirarchi Mar 12, 2024
de17cc0
fix arg name
AntonioMirarchi Mar 12, 2024
f947af4
rename to additional_methods
AntonioMirarchi Mar 12, 2024
4e00f8c
fix documentation
AntonioMirarchi Mar 13, 2024
2206a4c
fix ace dataloader with new extra_args name
AntonioMirarchi Mar 13, 2024
582d1ef
prefactor to device and dtype
AntonioMirarchi Mar 13, 2024
7547f34
initialize nn.Parameter with torch tensor
AntonioMirarchi Mar 13, 2024
2388670
fix argspace name for additional labels
AntonioMirarchi Mar 13, 2024
ea56bd3
fix prefactor operation
AntonioMirarchi Mar 13, 2024
89b0d31
more efficient
AntonioMirarchi Mar 14, 2024
10ee140
specify also the name of the mehods in addtional_methods dict
AntonioMirarchi Mar 14, 2024
cef4e64
remove unused trainable_rbf from tensornet embedding
AntonioMirarchi Mar 14, 2024
8958750
update additional_labels documentation in models
AntonioMirarchi Mar 14, 2024
4ef5aa9
remove extra_args expansion
AntonioMirarchi Mar 14, 2024
815b8c3
add extra_args expansion inside the model
AntonioMirarchi Mar 14, 2024
3cb232e
add test for additional labels
AntonioMirarchi Mar 15, 2024
2c8e47b
add additional_labels to load_example_args for testing
AntonioMirarchi Mar 15, 2024
6f7fac9
double check with and
AntonioMirarchi Mar 15, 2024
e1d8918
fix condition when extra args are passed to the forward
AntonioMirarchi Mar 15, 2024
40234f2
update to addtional_labels
AntonioMirarchi Mar 15, 2024
3b0dbf8
update test_examples
AntonioMirarchi Mar 15, 2024
433017e
update test wrappers
AntonioMirarchi Mar 15, 2024
d3af9dd
additional_methods to torchmd_GN_optimized
AntonioMirarchi Mar 15, 2024
4f3903a
small change, remove print from test_model
AntonioMirarchi Mar 15, 2024
2b18f90
to shared extra args nomenclature
AntonioMirarchi Mar 15, 2024
604034a
fix dipole_moments in the documentation, ace v2.0
AntonioMirarchi Mar 18, 2024
052261e
remove old comment
AntonioMirarchi Mar 18, 2024
c0902c5
update to get the additional labels from argparse as discussed in the PR
AntonioMirarchi Mar 18, 2024
917cb7a
fix documentation, include extra_args in the model's input
AntonioMirarchi Mar 18, 2024
452362b
to black
AntonioMirarchi Mar 18, 2024
cfd0f53
remove Any from typyng import
AntonioMirarchi Mar 18, 2024
5e02fc7
initialize_additional_method as free standing function
AntonioMirarchi Mar 18, 2024
3470f7e
remove Any from typing import because not used
AntonioMirarchi Mar 18, 2024
2ec850a
remove Any from typing import
AntonioMirarchi Mar 18, 2024
781a3a7
update for loop in the forward step and extra_args/static_shapes asse…
AntonioMirarchi Mar 18, 2024
1a30039
send always extra_args
AntonioMirarchi Mar 18, 2024
e811a5e
add reset_parameters for tensornetQ
AntonioMirarchi Mar 18, 2024
27f13db
add an assertion to verify that additional_labels is not specified fo…
AntonioMirarchi Mar 18, 2024
3ee76de
reintroduce charge and spin for backward compatibility
AntonioMirarchi Mar 18, 2024
d5a9e5f
to black
AntonioMirarchi Mar 18, 2024
75c6726
remove comment
AntonioMirarchi Mar 18, 2024
591253c
let tensornetq's operations more efficient
AntonioMirarchi Mar 18, 2024
4cd6313
move to label_callbacks instead of additional_methods
AntonioMirarchi Mar 18, 2024
6b72a00
rename to additional_labels_handler
AntonioMirarchi Mar 18, 2024
fc58fb3
get torch.jit compatibility, comprehension ifs are not supported yet
AntonioMirarchi Mar 18, 2024
0ac63d9
remove unused import
AntonioMirarchi Mar 18, 2024
8a4a181
use warning instead of assertion
AntonioMirarchi Mar 18, 2024
fdf12aa
more readable format
AntonioMirarchi Mar 18, 2024
580cd34
update test_model considering extra_args will be always passed to the…
AntonioMirarchi Mar 18, 2024
d483230
move prefactor and tensornet_q operation from the interacton layer to…
AntonioMirarchi Mar 19, 2024
9e546dd
fix typo
AntonioMirarchi Mar 19, 2024
b0d08e1
Merge branch 'main' into extra_fields_NNPs
AntonioMirarchi Apr 5, 2024
3 changes: 2 additions & 1 deletion tests/test_examples.py
@@ -5,6 +5,7 @@
 from pytest import mark
 import yaml
 import glob
+import torch
 from os.path import dirname, join
 from torchmdnet.models.model import create_model
 from torchmdnet import priors
@@ -27,4 +28,4 @@ def test_example_yamls(fname):

     z, pos, batch = create_example_batch()
     model(z, pos, batch)
-    model(z, pos, batch, q=None, s=None)
+    model(z, pos, batch, extra_args={"total_charge": torch.zeros_like(z)})
17 changes: 8 additions & 9 deletions tests/test_model.py
@@ -18,19 +18,19 @@

@mark.parametrize("model_name", models.__all_models__)
@mark.parametrize("use_batch", [True, False])
@mark.parametrize("explicit_q_s", [True, False])
@mark.parametrize("use_extra_args", [True, False])
@mark.parametrize("precision", [32, 64])
def test_forward(model_name, use_batch, explicit_q_s, precision):
@mark.parametrize("additional_labels", [None, {"tensornet_q": {"label": "total_charge", 'learnable': False, 'init_value': 0.1}}])
def test_forward(model_name, use_batch, use_extra_args, precision, additional_labels):
z, pos, batch = create_example_batch()
pos = pos.to(dtype=dtype_mapping[precision])
model = create_model(
load_example_args(model_name, prior_model=None, precision=precision)
)
model = create_model(load_example_args(model_name, prior_model=None, precision=precision, additional_labels=additional_labels))
batch = batch if use_batch else None
if explicit_q_s:
model(z, pos, batch=batch, q=None, s=None)
else:
if not use_extra_args and additional_labels is None:
model(z, pos, batch=batch)
else:
model(z, pos, batch=batch, extra_args={'total_charge': torch.zeros_like(z)})
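
The parametrized case above pairs an `additional_labels` entry (handler name `tensornet_q`, reading the `total_charge` label) with an `extra_args` dict keyed by that same label. A minimal sketch of how that correspondence could be checked before a forward call, in plain Python. `required_labels` and `check_extra_args` are hypothetical helpers written for illustration; they are not part of this PR or of torchmd-net, and a plain list stands in for the tensor.

```python
from typing import Any, Mapping, Optional, Set


def required_labels(
    additional_labels: Optional[Mapping[str, Mapping[str, Any]]]
) -> Set[str]:
    """Collect the label names that the configured handlers expect."""
    if additional_labels is None:
        return set()
    return {cfg["label"] for cfg in additional_labels.values()}


def check_extra_args(additional_labels, extra_args) -> None:
    """Raise KeyError if a configured label is absent from extra_args."""
    provided = set(extra_args or {})
    missing = required_labels(additional_labels) - provided
    if missing:
        raise KeyError(f"extra_args is missing labels: {sorted(missing)}")


# Same configuration as the parametrized test case above.
config = {
    "tensornet_q": {"label": "total_charge", "learnable": False, "init_value": 0.1}
}
check_extra_args(config, {"total_charge": [0, 0, 0]})  # list stands in for a tensor
check_extra_args(None, None)  # no labels configured, nothing to check
```

The check is keyed on the `label` value rather than on the handler name, which mirrors how the test passes `total_charge` in `extra_args` while the handler is called `tensornet_q`.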



@mark.parametrize("model_name", models.__all_models__)
@@ -137,7 +137,6 @@ def test_cuda_graph_compatible(model_name):
"output_model": "Scalar",
"reduce_op": "sum",
"precision": 32,
}
model = create_model(args).to(device="cuda")
model.eval()
z = z.to("cuda")
2 changes: 1 addition & 1 deletion tests/test_wrappers.py
@@ -19,7 +19,7 @@ def test_atom_filter(remove_threshold, model_name):
     model = AtomFilter(model, remove_threshold)

     z, pos, batch = create_example_batch(n_atoms=100)
-    x, v, z, pos, batch = model(z, pos, batch, None, None)
+    x, v, z, pos, batch = model(z, pos, batch, None)

     assert (z > remove_threshold).all(), (
         f"Lowest updated atomic number is {z.min()} but "
2 changes: 2 additions & 0 deletions tests/utils.py
@@ -26,6 +26,8 @@ def load_example_args(model_name, remove_prior=False, config_file=None, **kwargs
         args["box_vecs"] = None
     if "remove_ref_energy" not in args:
         args["remove_ref_energy"] = False
+    if "additional_labels" not in args:
+        args["additional_labels"] = None
     for key, val in kwargs.items():
         assert key in args, f"Broken test! Unknown key '{key}'."
         args[key] = val
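
The two added lines matter because of the assertion that follows them: an override is only accepted for a key that already exists in `args`, so an example config that does not mention `additional_labels` would otherwise reject `additional_labels=...` as an unknown key. A minimal standalone sketch of that default-then-override pattern; `apply_overrides` and the sample values are hypothetical and only mirror the logic shown in the hunk.

```python
def apply_overrides(args: dict, **kwargs) -> dict:
    """Fill optional defaults, then apply overrides, rejecting unknown keys."""
    merged = dict(args)  # work on a copy so the caller's dict is untouched
    # Optional keys get a default first, so overriding them is allowed.
    merged.setdefault("remove_ref_energy", False)
    merged.setdefault("additional_labels", None)
    for key, val in kwargs.items():
        assert key in merged, f"Broken test! Unknown key '{key}'."
        merged[key] = val
    return merged


# Hypothetical base config that does not define additional_labels.
base = {"model": "tensornet", "precision": 32}
out = apply_overrides(
    base, additional_labels={"tensornet_q": {"label": "total_charge"}}
)
```

Without the `setdefault` for `additional_labels`, the call above would fail the assertion even though the key is a legitimate option.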
30 changes: 15 additions & 15 deletions torchmdnet/datasets/ace.py
@@ -146,7 +146,7 @@ def __init__(
transform,
pre_transform,
pre_filter,
properties=("y", "neg_dy", "q", "pq", "dp"),
properties=("y", "neg_dy", "total_charge", "partial_charges", "dipole_moment"),
)

@property
@@ -188,14 +188,14 @@ def _load_confs_1_0(mol, n_atoms):
assert neg_dy.shape == pos.shape

assert conf["partial_charges"].attrs["units"] == "e"
pq = pt.tensor(conf["partial_charges"][:], dtype=pt.float32)
assert pq.shape == (n_atoms,)
partial_charges = pt.tensor(conf["partial_charges"][:], dtype=pt.float32)
assert partial_charges.shape == (n_atoms,)

assert conf["dipole_moment"].attrs["units"] == "e*Å"
dp = pt.tensor(conf["dipole_moment"][:], dtype=pt.float32)
assert dp.shape == (3,)
dipole_moment = pt.tensor(conf["dipole_moment"][:], dtype=pt.float32)
assert dipole_moment.shape == (3,)

yield pos, y, neg_dy, pq, dp
yield pos, y, neg_dy, partial_charges, dipole_moment

@staticmethod
def _load_confs_2_0(mol, n_atoms):
@@ -213,19 +213,19 @@ def _load_confs_2_0(mol, n_atoms):
assert all_neg_dy.shape == all_pos.shape

assert mol["partial_charges"].attrs["units"] == "e"
all_pq = pt.tensor(mol["partial_charges"][...], dtype=pt.float32)
assert all_pq.shape == (n_confs, n_atoms)
all_partial_charges = pt.tensor(mol["partial_charges"][...], dtype=pt.float32)
assert all_partial_charges.shape == (n_confs, n_atoms)

assert mol["dipole_moments"].attrs["units"] == "e*Å"
all_dp = pt.tensor(mol["dipole_moments"][...], dtype=pt.float32)
assert all_dp.shape == (n_confs, 3)
all_dipole_moment = pt.tensor(mol["dipole_moments"][...], dtype=pt.float32)
assert all_dipole_moment.shape == (n_confs, 3)

for pos, y, neg_dy, pq, dp in zip(all_pos, all_y, all_neg_dy, all_pq, all_dp):
for pos, y, neg_dy, partial_charges, dipole_moment in zip(all_pos, all_y, all_neg_dy, all_partial_charges, all_dipole_moment):
# Skip failed calculations
if y.isnan():
continue

yield pos, y, neg_dy, pq, dp
yield pos, y, neg_dy, partial_charges, dipole_moment

def sample_iter(self, mol_ids=False):
assert self.subsample_molecules > 0
@@ -261,9 +261,9 @@ def sample_iter(self, mol_ids=False):

z = pt.tensor(mol["atomic_numbers"], dtype=pt.long)
fq = pt.tensor(mol["formal_charges"], dtype=pt.long)
q = fq.sum()
total_charge = fq.sum()

for i_conf, (pos, y, neg_dy, pq, dp) in enumerate(
for i_conf, (pos, y, neg_dy, partial_charges, dipole_moment) in enumerate(
load_confs(mol, n_atoms=len(z))
):
# Skip samples with large forces
@@ -273,7 +273,7 @@ def sample_iter(self, mol_ids=False):

# Create a sample
args = dict(
z=z, pos=pos, y=y.view(1, 1), neg_dy=neg_dy, q=q, pq=pq, dp=dp
z=z, pos=pos, y=y.view(1, 1), neg_dy=neg_dy, total_charge=total_charge, partial_charges=partial_charges, dipole_moment=dipole_moment
)
if mol_ids:
args["mol_id"] = mol_id
8 changes: 8 additions & 0 deletions torchmdnet/datasets/hdf.py
@@ -57,6 +57,14 @@ def __init__(self, filename, dataset_preload_limit=1024, **kwargs):
self.fields.append(
("partial_charges", "partial_charges", torch.float32)
)
if "total_charge" in group:
self.fields.append(
("total_charge", "total_charge", torch.float32)
)
if "spin" in group:
self.fields.append(
("spin", "spin", torch.float32)
)
assert ("energy" in group) or (
"forces" in group
), "Each group must contain at least energies or forces"
90 changes: 45 additions & 45 deletions torchmdnet/datasets/memdataset.py
@@ -17,9 +17,9 @@ class MemmappedDataset(Dataset):
- :obj:`pos`: Positions of the atoms.
- :obj:`y`: Energy of the conformation.
- :obj:`neg_dy`: Forces on the atoms.
- :obj:`q`: Total charge of the conformation.
- :obj:`pq`: Partial charges of the atoms.
- :obj:`dp`: Dipole moment of the conformation.
- :obj:`total_charge`: Total charge of the conformation.
- :obj:`partial_charges`: Partial charges of the atoms.
- :obj:`dipole_moment`: Dipole moment of the conformation.

The data is stored in the following files:

@@ -28,9 +28,9 @@ class MemmappedDataset(Dataset):
- :obj:`name.pos.mmap`: Positions of all the atoms.
- :obj:`name.y.mmap`: Energy of each conformation.
- :obj:`name.neg_dy.mmap`: Forces on all the atoms.
- :obj:`name.q.mmap`: Total charge of each conformation.
- :obj:`name.pq.mmap`: Partial charges of all the atoms.
- :obj:`name.dp.mmap`: Dipole moment of each conformation.
- :obj:`name.total_charge.mmap`: Total charge of each conformation.
- :obj:`name.partial_charges.mmap`: Partial charges of all the atoms.
- :obj:`name.dipole_moment.mmap`: Dipole moment of each conformation.

Args:
root (str): Root directory where the dataset should be stored.
@@ -45,8 +45,8 @@ class MemmappedDataset(Dataset):
indicating whether the data object should be included in the final
dataset.
properties (tuple of str, optional): The properties to include in the
dataset. Can be any subset of :obj:`y`, :obj:`neg_dy`, :obj:`q`,
:obj:`pq`, and :obj:`dp`.
dataset. Can be any subset of :obj:`y`, :obj:`neg_dy`, :obj:`total_charge`,
:obj:`partial_charges`, and :obj:`dipole_moment`.
"""

def __init__(
@@ -55,7 +55,7 @@ def __init__(
transform=None,
pre_transform=None,
pre_filter=None,
properties=("y", "neg_dy", "q", "pq", "dp"),
properties=("y", "neg_dy", "total_charge", "partial_charges", "dipole_moment"),
):
self.name = self.__class__.__name__
self.properties = properties
@@ -76,13 +76,13 @@ def __init__(
self.neg_dy_mm = np.memmap(
fnames["neg_dy"], mode="r", dtype=np.float32, shape=(num_all_atoms, 3)
)
if "q" in self.properties:
self.q_mm = np.memmap(fnames["q"], mode="r", dtype=np.int8)
if "pq" in self.properties:
self.pq_mm = np.memmap(fnames["pq"], mode="r", dtype=np.float32)
if "dp" in self.properties:
if "total_charge" in self.properties:
self.q_mm = np.memmap(fnames["total_charge"], mode="r", dtype=np.int8)
if "partial_charges" in self.properties:
self.pq_mm = np.memmap(fnames["partial_charges"], mode="r", dtype=np.float32)
if "dipole_moment" in self.properties:
self.dp_mm = np.memmap(
fnames["dp"], mode="r", dtype=np.float32, shape=(num_all_confs, 3)
fnames["dipole_moment"], mode="r", dtype=np.float32, shape=(num_all_confs, 3)
)

assert self.idx_mm[0] == 0
@@ -151,20 +151,20 @@ def process(self):
dtype=np.float32,
shape=(num_all_atoms, 3),
)
if "q" in self.properties:
if "total_charge" in self.properties:
q_mm = np.memmap(
fnames["q"] + ".tmp", mode="w+", dtype=np.int8, shape=num_all_confs
fnames["total_charge"] + ".tmp", mode="w+", dtype=np.int8, shape=num_all_confs
)
if "pq" in self.properties:
if "partial_charges" in self.properties:
pq_mm = np.memmap(
fnames["pq"] + ".tmp",
fnames["partial_charges"] + ".tmp",
mode="w+",
dtype=np.float32,
shape=num_all_atoms,
)
if "dp" in self.properties:
if "dipole_moment" in self.properties:
dp_mm = np.memmap(
fnames["dp"] + ".tmp",
fnames["dipole_moment"] + ".tmp",
mode="w+",
dtype=np.float32,
shape=(num_all_confs, 3),
@@ -182,12 +182,12 @@ def process(self):
y_mm[i_conf] = data.y
if "neg_dy" in self.properties:
neg_dy_mm[i_atom:i_next_atom] = data.neg_dy
if "q" in self.properties:
q_mm[i_conf] = data.q.to(pt.int8)
if "pq" in self.properties:
pq_mm[i_atom:i_next_atom] = data.pq
if "dp" in self.properties:
dp_mm[i_conf] = data.dp
if "total_charge" in self.properties:
q_mm[i_conf] = data.total_charge.to(pt.int8)
if "partial_charges" in self.properties:
pq_mm[i_atom:i_next_atom] = data.partial_charges
if "dipole_moment" in self.properties:
dp_mm[i_conf] = data.dipole_moment
i_atom = i_next_atom

idx_mm[-1] = num_all_atoms
@@ -200,11 +200,11 @@ def process(self):
y_mm.flush()
if "neg_dy" in self.properties:
neg_dy_mm.flush()
if "q" in self.properties:
if "total_charge" in self.properties:
q_mm.flush()
if "pq" in self.properties:
if "partial_charges" in self.properties:
pq_mm.flush()
if "dp" in self.properties:
if "dipole_moment" in self.properties:
dp_mm.flush()

os.rename(idx_mm.filename, fnames["idx"])
@@ -214,12 +214,12 @@ def process(self):
os.rename(y_mm.filename, fnames["y"])
if "neg_dy" in self.properties:
os.rename(neg_dy_mm.filename, fnames["neg_dy"])
if "q" in self.properties:
os.rename(q_mm.filename, fnames["q"])
if "pq" in self.properties:
os.rename(pq_mm.filename, fnames["pq"])
if "dp" in self.properties:
os.rename(dp_mm.filename, fnames["dp"])
if "total_charge" in self.properties:
os.rename(q_mm.filename, fnames["total_charge"])
if "partial_charges" in self.properties:
os.rename(pq_mm.filename, fnames["partial_charges"])
if "dipole_moment" in self.properties:
os.rename(dp_mm.filename, fnames["dipole_moment"])

def len(self):
return len(self.idx_mm) - 1
@@ -233,9 +233,9 @@ def get(self, idx):
- :obj:`pos`: Positions of the atoms.
- :obj:`y`: Formation energy of the molecule.
- :obj:`neg_dy`: Forces on the atoms.
- :obj:`q`: Total charge of the molecule.
- :obj:`pq`: Partial charges of the atoms.
- :obj:`dp`: Dipole moment of the molecule.
- :obj:`total_charge`: Total charge of the molecule.
- :obj:`partial_charges`: Partial charges of the atoms.
- :obj:`dipole_moment`: Dipole moment of the molecule.

Args:
idx (int): Index of the data object.
@@ -252,10 +252,10 @@ def get(self, idx):
props["y"] = pt.tensor(self.y_mm[idx]).view(1, 1)
if "neg_dy" in self.properties:
props["neg_dy"] = pt.tensor(self.neg_dy_mm[atoms])
if "q" in self.properties:
props["q"] = pt.tensor(self.q_mm[idx], dtype=pt.long)
if "pq" in self.properties:
props["pq"] = pt.tensor(self.pq_mm[atoms])
if "dp" in self.properties:
props["dp"] = pt.tensor(self.dp_mm[idx])
if "total_charge" in self.properties:
props["total_charge"] = pt.tensor(self.q_mm[idx], dtype=pt.long)
if "partial_charges" in self.properties:
props["partial_charges"] = pt.tensor(self.pq_mm[atoms])
if "dipole_moment" in self.properties:
props["dipole_moment"] = pt.tensor(self.dp_mm[idx])
return Data(z=z, pos=pos, **props)
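
The docstring above changes the documented file names from `name.q.mmap`, `name.pq.mmap` and `name.dp.mmap` to `name.total_charge.mmap`, `name.partial_charges.mmap` and `name.dipole_moment.mmap`. Assuming the on-disk names really follow that `name.<property>.mmap` pattern (the diff does not show how `fnames` is built), files processed before this change would not be found under the new names. A rough sketch of renaming them instead of reprocessing; `plan_renames` and the `Example` dataset name are hypothetical and not part of the PR.

```python
import tempfile
from pathlib import Path
from typing import List, Tuple

# Old short property names and the descriptive names used after this change.
RENAMED = {"q": "total_charge", "pq": "partial_charges", "dp": "dipole_moment"}


def plan_renames(directory: Path, name: str) -> List[Tuple[Path, Path]]:
    """Return (old, new) path pairs for memmap files that use an old name."""
    plan = []
    for old, new in RENAMED.items():
        src = directory / f"{name}.{old}.mmap"
        if src.exists():
            plan.append((src, directory / f"{name}.{new}.mmap"))
    return plan


# Demonstration on empty placeholder files in a temporary directory.
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    (root / "Example.q.mmap").touch()
    (root / "Example.y.mmap").touch()  # unchanged property, must be left alone
    for src, dst in plan_renames(root, "Example"):
        src.rename(dst)
    remaining = sorted(p.name for p in root.iterdir())
```

Only the three renamed properties are touched; `y`, `neg_dy` and the index files keep their names in this change.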
4 changes: 2 additions & 2 deletions torchmdnet/datasets/qm9q.py
@@ -46,7 +46,7 @@ def __init__(
transform,
pre_transform,
pre_filter,
properties=("y", "neg_dy", "q", "pq", "dp"),
properties=("y", "neg_dy", "total_charge", "partial_charges", "dipole_moment"),
)

@property
@@ -150,7 +150,7 @@ def sample_iter(self, mol_ids=False):

# Create a sample
args = dict(
z=z, pos=pos, y=y.view(1, 1), neg_dy=neg_dy, q=q, pq=pq, dp=dp
z=z, pos=pos, y=y.view(1, 1), neg_dy=neg_dy, total_charge=q, partial_charges=pq, dipole_moment=dp
)
if mol_ids:
args["mol_id"] = mol_id