diff --git a/torchmdnet/datasets/hdf.py b/torchmdnet/datasets/hdf.py index 3d817a50..ee6a5e0f 100644 --- a/torchmdnet/datasets/hdf.py +++ b/torchmdnet/datasets/hdf.py @@ -53,9 +53,13 @@ def __init__(self, filename, dataset_preload_limit=1024, **kwargs): self.fields.append(("y", "energy", torch.float32)) if "forces" in group: self.fields.append(("neg_dy", "forces", torch.float32)) + if "charge" in group: + self.fields.append(("q", "charge", torch.float32)) + if "spin" in group: + self.fields.append(("s", "spin", torch.float32)) if "partial_charges" in group: self.fields.append( - ("partial_charges", "partial_charges", torch.float32) + ("pq", "partial_charges", torch.float32) ) assert ("energy" in group) or ( "forces" in group