TorchScript compatibility for Ensemble #312
Conversation
torchmdnet/models/model.py (Outdated)

```diff
         if self.return_std:
             return y_mean, neg_dy_mean, y_std, neg_dy_std
         else:
-            return y_mean, neg_dy_mean
+            return y_mean, neg_dy_mean, None, None
```
No point in returning None, None. The only reason for the if was to return fewer values.
Remove the if and always return y_mean, neg_dy_mean, y_std, neg_dy_std.
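A minimal sketch of what that could look like, assuming the per-model outputs are collected into lists before reduction (the helper name and the stacking logic are illustrative, not the actual torchmd-net code):

```python
import torch
from torch import Tensor
from typing import List, Tuple

# Hypothetical sketch: reduce per-model outputs and always return all four
# values (mean output, mean negative derivative, and their stds), with no
# return_std branch.
def ensemble_reduce(ys: List[Tensor], neg_dys: List[Tensor]) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    y_stack = torch.stack(ys)            # (n_models, ...)
    neg_dy_stack = torch.stack(neg_dys)  # (n_models, ...)
    y_mean, y_std = y_stack.mean(dim=0), y_stack.std(dim=0)
    neg_dy_mean, neg_dy_std = neg_dy_stack.mean(dim=0), neg_dy_stack.std(dim=0)
    return y_mean, neg_dy_mean, y_std, neg_dy_std
```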
And remove the option from the class, and from the loader as well.
torchmdnet/models/model.py (Outdated)

```python
        Returns:
            Tuple[Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor]]: The mean output of the models, the mean derivatives, the std of the outputs if return_std is true, the std of the derivatives if return_std is true.
```
say "the mean negative derivatives"?
I believe something like this should stay in this docstring:

```
Average predictions over all models in the ensemble.
The arguments to this function are simply relayed to the forward method of each :py:mod:`TorchMD_Net` model in the ensemble.
```

Please add a test to check for TorchScript compatibility.
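A minimal sketch of such a test (the `load_model` import path and the `ensemble_zip` fixture name are assumptions, mirroring the test snippet further down):

```python
import torch
from torchmdnet.models.model import load_model  # assumed import path

# Sketch of a TorchScript-compatibility check for the ensemble model;
# `ensemble_zip` is assumed to be a fixture pointing at a zipped ensemble checkpoint.
def test_ensemble_torchscript(ensemble_zip):
    ensemble_model = load_model(ensemble_zip, return_std=True)
    scripted = torch.jit.script(ensemble_model)
    assert isinstance(scripted, torch.jit.ScriptModule)
```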
No, I think at this point we have given up trying to mimic the TorchMD_Net class outputs.
torchmdnet/models/model.py (Outdated)

```python
        q: Optional[Tensor] = None,
        s: Optional[Tensor] = None,
        extra_args: Optional[Dict[str, Tensor]] = None,
    ) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
```
Should be `Tuple[Tensor, Optional[Tensor], Tensor, Optional[Tensor]]`.
Note that this function will fail if `derivative=False`.
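A toy sketch of that annotation (an illustrative module, not the torchmd-net one): the second and fourth entries are the ones that can be None when `derivative=False`, so those are the Optional slots.

```python
import torch
from torch import Tensor
from typing import Optional, Tuple

class Toy(torch.nn.Module):
    def __init__(self, derivative: bool = False):
        super().__init__()
        self.derivative = derivative

    # Suggested ordering: y, Optional neg_dy, y_std, Optional neg_dy_std.
    def forward(self, z: Tensor) -> Tuple[Tensor, Optional[Tensor], Tensor, Optional[Tensor]]:
        y = z.sum().unsqueeze(0)
        y_std = torch.zeros_like(y)
        if self.derivative:
            return y, -z, y_std, torch.zeros_like(z)
        return y, None, y_std, None

# Scripts fine even though neg_dy and its std can be None.
scripted = torch.jit.script(Toy())
```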
This works:

```python
import torch
from torch import Tensor
from typing import Union, Tuple


class Mymod(torch.nn.Module):
    def __init__(self, return3=False):
        super(Mymod, self).__init__()
        self.return3 = return3

    def forward(
        self, z: Tensor
    ) -> Union[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor, Tensor]]:
        if self.return3:
            return z, z, z
        else:
            return z, z


mymod = torch.jit.script(Mymod())
z = torch.ones(10)
o1, o2 = mymod(z)
mymod = torch.jit.script(Mymod(return3=True))
o1, o2, o3 = mymod(z)
mymod = torch.jit.script(Mymod())
o1, _ = mymod(z)
```
Yes, but I can't make this work:

```python
import torch
from torch import Tensor
from typing import Union, Tuple


class Mymod(torch.nn.Module):
    def __init__(self, return3=False):
        super(Mymod, self).__init__()
        self.return3 = return3

    def forward(
        self, z: Tensor
    ) -> Union[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor, Tensor]]:
        if self.return3:
            return z, z, z
        else:
            return z, z


class MymodWrapper(torch.nn.Module):
    def __init__(self):
        super(MymodWrapper, self).__init__()
        self.model = Mymod()

    def forward(self, z: Tensor) -> Tensor:
        o1, _ = self.model(z)
        return o1


mymod = torch.jit.script(Mymod())
z = torch.ones(10)
o1, o2 = mymod(z)
mymod = torch.jit.script(Mymod(return3=True))
o1, o2, o3 = mymod(z)
mymod = torch.jit.script(Mymod())
o1, _ = mymod(z)
mymodwrapper = MymodWrapper()
o1 = mymodwrapper(z)
mymodwrapper = torch.jit.script(MymodWrapper())
o1 = mymodwrapper(z)
```

For OpenMM-torch compatibility there will be wrapper code that needs to be torchscripted too.
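For reference, a toy sketch (illustrative classes, not the torchmd-net ones) of the alternative that sidesteps the problem: with a fixed-length tuple return whose optional entries are annotated `Optional[Tensor]`, the unpack arity is known statically, so a scripted wrapper compiles.

```python
import torch
from torch import Tensor
from typing import Optional, Tuple

class FixedMod(torch.nn.Module):
    # Always returns four entries; the ones that may be absent are None.
    def forward(self, z: Tensor) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
        return z, z, None, None

class FixedWrapper(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.model = FixedMod()

    def forward(self, z: Tensor) -> Tensor:
        # Unpacking works here because the tuple length is fixed.
        y, neg_dy, y_std, neg_dy_std = self.model(z)
        return y

wrapper = torch.jit.script(FixedWrapper())
out = wrapper(torch.ones(10))
```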
Dang! That looks like a torch bug. I do not see why one would work and the other not.
```python
ensemble_model = load_model(ensemble_zip, return_std=True)

if check_script:
    ensemble_model = torch.jit.script(load_model(ckpts))
```
Typo here, this should be `ensemble_zip` not `ckpts`.
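That is, the scripted branch would presumably read:

```python
ensemble_model = torch.jit.script(load_model(ensemble_zip))
```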
No need for an urgent merge here, we will see if it works in a production simulation.
I could not make

```python
Union[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor, Tensor, Tensor]]
```

work. It will pass `torch.jit.script`, but then if I try to use the model as `energy, _ = model(..)` it will complain.

With this I can do:

```python
energies, *_ = self.model(self.atomic_numbers, positions, batch=self.batch, q=self.total_charges)
```

And I can switch between the ensemble and a single model without any changes.
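A small eager-Python illustration (toy functions, not the torchmd-net API) of why the starred unpack is convenient: the same call site works whether the model returns two values or four.

```python
import torch
from torch import Tensor
from typing import Tuple

# Toy stand-ins: a single model returning (energy, forces-like) and an
# ensemble returning (energy, forces-like, energy_std, forces_std).
def single_model(z: Tensor) -> Tuple[Tensor, Tensor]:
    return z.sum(), -z

def ensemble_model(z: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    return z.sum(), -z, torch.zeros(()), torch.zeros_like(z)

z = torch.ones(10)
for model in (single_model, ensemble_model):
    energies, *_ = model(z)  # extra outputs are simply discarded
```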