TorchScript compatibility for Ensemble #312

Open

sef43 wants to merge 3 commits into main
Conversation

sef43
Collaborator

sef43 commented Apr 2, 2024

I could not make Union[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor, Tensor, Tensor]] work.
It will jit.script, but if I then try to use the model as
energy, _ = model(..) it complains:


RuntimeError: 
Union[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor, Tensor, Tensor]] cannot be used as a tuple:

    def forward(self, positions):
        positions = pt.index_select(positions, 0, self.all_atom_indices).to(pt.float32) * 10 # nm --> A
        energies, _ = self.model(self.atomic_numbers, positions, batch=self.batch, q=self.total_charges)
                      ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE

With this I can do:

energies, *_ = self.model(self.atomic_numbers, positions, batch=self.batch, q=self.total_charges)

and I can switch between an ensemble and a single model without any changes (see the small illustration below).
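For illustration, a tiny eager-mode sketch of that unpacking pattern, using placeholder tuples rather than real model outputs (single_out, ensemble_out and the values are made up): the same left-hand side works whether two or four values come back.

# Placeholder outputs standing in for a single model (2 values) and an
# ensemble (4 values); only the number of returned values matters here.
single_out = (1.0, 2.0)                # (energy, neg_dy)
ensemble_out = (1.0, 2.0, 0.1, 0.2)    # (energy, neg_dy, y_std, neg_dy_std)

for out in (single_out, ensemble_out):
    energy, *_ = out  # any extra values are collected and ignored
    print(energy)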


if self.return_std:
    return y_mean, neg_dy_mean, y_std, neg_dy_std
else:
    return y_mean, neg_dy_mean
return y_mean, neg_dy_mean, None, None
Collaborator

No point in returning None, None. The only reason for the if was to return fewer variables.
Remove the if and always return y_mean, neg_dy_mean, y_std, neg_dy_std.
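A minimal sketch of that suggestion, assuming the per-model outputs are already stacked along dim 0 (the class name, arguments, and shapes are illustrative, not the actual Ensemble code; only y_mean, neg_dy_mean, y_std, neg_dy_std come from the diff above): a single fixed-arity return type scripts cleanly, with no Union and no Nones.

import torch
from torch import Tensor
from typing import Tuple

class EnsembleReturnSketch(torch.nn.Module):
    # Stand-in for Ensemble.forward: always return the full 4-tuple so
    # TorchScript sees one concrete return type.
    def forward(self, y: Tensor, neg_dy: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        # y and neg_dy are assumed to be stacked over models along dim 0
        y_mean, neg_dy_mean = y.mean(dim=0), neg_dy.mean(dim=0)
        y_std, neg_dy_std = y.std(dim=0), neg_dy.std(dim=0)
        return y_mean, neg_dy_mean, y_std, neg_dy_std

scripted = torch.jit.script(EnsembleReturnSketch())  # no Union return, scripts fine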

Collaborator

And remove the option from the class, and also from the loader.


Returns:
    Tuple[Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor]]: The mean output of the models, the mean derivatives, the std of the outputs if return_std is true, the std of the derivatives if return_std is true.
Collaborator

say "the mean negative derivatives"?

Collaborator

I believe something like this should stay in this docstring:

Average predictions over all models in the ensemble.
The arguments to this function are simply relayed to the forward method of each :py:mod:`TorchMD_Net` model in the ensemble.

@RaulPPelaez
Collaborator

Please add a test to check for TorchScript compatibility.
Is Ensemble supposed to be able to mimic TorchMD_Net? In that case it should really return just y, neg_dy by default, without Nones.

@stefdoerr
Collaborator

No, I think at this point we have given up trying to mimic the TorchMD_Net class outputs.

    q: Optional[Tensor] = None,
    s: Optional[Tensor] = None,
    extra_args: Optional[Dict[str, Tensor]] = None,
) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
Collaborator

Should be Tuple[Tensor, Optional[Tensor], Tensor, Optional[Tensor]]:
Note that this function will fail if derivative=False.

@RaulPPelaez
Collaborator

This works:

import torch
from torch import Tensor
from typing import Union, Tuple


class Mymod(torch.nn.Module):

    def __init__(self, return3=False):
        super(Mymod, self).__init__()
        self.return3 = return3

    def forward(
        self, z: Tensor
    ) -> Union[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor, Tensor]]:
        if self.return3:
            return z, z, z
        else:
            return z, z

mymod = torch.jit.script(Mymod())
z = torch.ones(10)
o1, o2 = mymod(z)
mymod = torch.jit.script(Mymod(return3=True))
o1, o2, o3 = mymod(z)
mymod = torch.jit.script(Mymod())
o1, _ = mymod(z)

@sef43
Collaborator Author

sef43 commented Apr 2, 2024

Yes, but I can't make this work:

import torch
from torch import Tensor
from typing import Union, Tuple


class Mymod(torch.nn.Module):

    def __init__(self, return3=False):
        super(Mymod, self).__init__()
        self.return3 = return3

    def forward(
        self, z: Tensor
    ) -> Union[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor, Tensor]]:
        if self.return3:
            return z, z, z
        else:
            return z, z


class MymodWrapper(torch.nn.Module):
    def __init__(self):
        super(MymodWrapper, self).__init__()

        self.model = Mymod()
        
    def forward(
        self, z: Tensor
    ) -> Tensor:

        o1,_ = self.model(z)

        return o1


mymod = torch.jit.script(Mymod())
z = torch.ones(10)
o1, o2 = mymod(z)
mymod = torch.jit.script(Mymod(return3=True))
o1, o2, o3 = mymod(z)
mymod = torch.jit.script(Mymod())
o1, _ = mymod(z)

mymodwrapper = MymodWrapper()

o1 = mymodwrapper(z)

mymodwrapper = torch.jit.script(MymodWrapper())

o1 = mymodwrapper(z)
Traceback (most recent call last):
  File "/home/sfarr/torchmd-net/tests/temp.py", line 48, in <module>
    mymodwrapper = torch.jit.script(MymodWrapper())
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/sfarr/miniconda3/envs/tn_atm_dev/lib/python3.11/site-packages/torch/jit/_script.py", line 1324, in script
    return torch.jit._recursive.create_script_module(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/sfarr/miniconda3/envs/tn_atm_dev/lib/python3.11/site-packages/torch/jit/_recursive.py", line 559, in create_script_module
    return create_script_module_impl(nn_module, concrete_type, stubs_fn)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/sfarr/miniconda3/envs/tn_atm_dev/lib/python3.11/site-packages/torch/jit/_recursive.py", line 636, in create_script_module_impl
    create_methods_and_properties_from_stubs(
  File "/home/sfarr/miniconda3/envs/tn_atm_dev/lib/python3.11/site-packages/torch/jit/_recursive.py", line 469, in create_methods_and_properties_from_stubs
    concrete_type._create_methods_and_properties(
RuntimeError: 
Union[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor, Tensor]] cannot be used as a tuple:
  File "/home/sfarr/torchmd-net/tests/temp.py", line 31
    ) -> Tensor:
    
        o1,_ = self.model(z)
               ~~~~~~~~~~~~ <--- HERE
    
        return o1

For OpenMM-Torch compatibility there will be wrapper code that needs to be TorchScripted too; a rough sketch of such a wrapper follows below.
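For context, a rough sketch of what such a wrapper looks like, with a dummy inner module standing in for the ensemble (DummyEnsemble, ForceWrapper and the shapes are illustrative, not the PR's actual OpenMM-Torch code): once the inner model has a single fixed-arity return type, the wrapper can unpack it and script without the Union error above.

import torch
from torch import Tensor
from typing import Optional, Tuple

class DummyEnsemble(torch.nn.Module):
    # Stand-in for the ensemble after this PR: one fixed-arity return type.
    def forward(self, z: Tensor, pos: Tensor) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
        energy = pos.sum().reshape(1)
        return energy, None, None, None

class ForceWrapper(torch.nn.Module):
    # OpenMM-Torch style wrapper: maps positions (nm) to an energy and must
    # itself be scriptable, which is why the inner return type matters.
    def __init__(self):
        super().__init__()
        self.model = DummyEnsemble()
        self.register_buffer("atomic_numbers", torch.ones(10, dtype=torch.long))

    def forward(self, positions: Tensor) -> Tensor:
        positions = positions.to(torch.float32) * 10.0  # nm -> Angstrom
        energy, neg_dy, y_std, neg_dy_std = self.model(self.atomic_numbers, positions)
        return energy

scripted = torch.jit.script(ForceWrapper())  # scripts now that the return arity is fixed
print(scripted(torch.zeros(10, 3)))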

@RaulPPelaez
Collaborator

RaulPPelaez commented Apr 2, 2024

Dang! That looks like a torch bug. I do not see why one would work and the other not.
OK, let's drop TorchMD_Net compatibility for the moment, then.
EDIT: I opened pytorch/pytorch#123168

ensemble_model = load_model(ensemble_zip, return_std=True)

if check_script:
    ensemble_model = torch.jit.script(load_model(ckpts))
Collaborator Author

Typo here: this should be ensemble_zip, not ckpts.

@sef43
Copy link
Collaborator Author

sef43 commented Apr 2, 2024

No need for an urgent merge here; we will see if it works in production simulations.
