Allow torch model to be used as Integrator #56
That assumes a different computational model from how OpenMM works. For example, it has no concept of computing forces for an arbitrary set of positions. It can only compute them for the current positions that are set in the context. Instead, it would be best to define functions that closely map to the operations in `CustomIntegrator`: get and set positions, get and set velocities, compute forces and/or energy for the current positions, apply constraints, etc.
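A minimal sketch of that stateful model, using a toy one-dimensional harmonic system in plain Python. `ToyContext` and its method names are hypothetical stand-ins, not part of the OpenMM or openmm-torch API. The point is that `compute_forces()` takes no positions argument, so forces are only ever available for whatever positions are currently set.

```python
class ToyContext:
    """Hypothetical stand-in for an OpenMM Context (toy 1D harmonic system)."""

    def __init__(self, positions, velocities, k=1.0):
        self._x = list(positions)
        self._v = list(velocities)
        self.k = k  # spring constant of the toy potential E = 0.5 * k * x**2

    def get_positions(self):
        return list(self._x)

    def set_positions(self, x):
        self._x = list(x)

    def get_velocities(self):
        return list(self._v)

    def set_velocities(self, v):
        self._v = list(v)

    def compute_forces(self):
        # No positions argument: forces always refer to the current state.
        return [-self.k * xi for xi in self._x]

    def compute_energy(self):
        return sum(0.5 * self.k * xi * xi for xi in self._x)

    def apply_constraints(self):
        # The toy system has no constraints, so this is a no-op.
        pass


ctx = ToyContext([1.0, -2.0], [0.0, 0.0], k=2.0)
print(ctx.compute_forces())  # [-2.0, 4.0]
ctx.set_positions([0.5, 1.0])
print(ctx.compute_forces())  # [-1.0, -2.0]
```

To get forces at other positions, a caller has to set them first, which is exactly the constraint an integrator API built on OpenMM would inherit.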
It would be possible to implement support for such models:

```python
import torch


class IntegratorModule(torch.nn.Module):
    def forward(self, positions, velocities, forces, arbitrary_number_of_scalars):
        # Do some computation (placeholder: return the state unchanged)
        new_positions = positions
        new_velocities = velocities
        return new_positions, new_velocities
```
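An integrator that needs forces only once per step fits this signature. As an illustration, here is a leapfrog step written against it in plain Python, with lists standing in for tensors. `leapfrog_step` is a hypothetical name, the extra scalars are reduced to `dt` and `mass`, and the velocities are interpreted as half-step velocities v(t - dt/2).

```python
def leapfrog_step(positions, velocities, forces, dt, mass):
    """One leapfrog step: forces are needed only at the current positions.

    `velocities` are half-step velocities v(t - dt/2); the returned velocities
    are v(t + dt/2), and the returned positions are x(t + dt).
    """
    new_velocities = [v + dt * f / mass for v, f in zip(velocities, forces)]
    new_positions = [x + dt * v for x, v in zip(positions, new_velocities)]
    return new_positions, new_velocities


# Toy harmonic oscillator (k = 1, m = 1): the caller supplies f = -k * x each step.
x, v = [1.0], [0.0]
for _ in range(3):
    f = [-xi for xi in x]
    x, v = leapfrog_step(x, v, f, dt=0.1, mass=1.0)
print(x, v)
```

Here the caller evaluates forces once per step, before calling the module. An integrator that needs forces again after the positions have moved cannot be expressed this way without a second call into the force computation.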
That assumes you only need the forces at the start of the step. Integrators often need the forces at multiple points throughout the step. Here's an example of a velocity Verlet integrator implemented with `CustomIntegrator`:

```cpp
CustomIntegrator integrator(0.001);
integrator.addPerDofVariable("x1", 0);
integrator.addUpdateContextState();
integrator.addComputePerDof("v", "v+0.5*dt*f/m");
integrator.addComputePerDof("x", "x+dt*v");
integrator.addComputePerDof("x1", "x");
integrator.addConstrainPositions();
integrator.addComputePerDof("v", "v+0.5*dt*f/m+(x-x1)/dt");
integrator.addConstrainVelocities();
```

I imagine that the PyTorch implementation might look something like this:

```python
class VelocityVerletIntegrator(openmmtorch.Integrator):
    def forward(self):
        dt = self.dt
        m = self.m
        self.updateContextState()
        self.v = self.v+0.5*dt*self.f/m
        x1 = self.x+dt*self.v
        self.x = x1
        self.constrainPositions()
        # As in CustomIntegrator, f must be the forces at the updated positions
        self.v = self.v+0.5*dt*self.f/m+(self.x-x1)/dt
        self.constrainVelocities()
```

I assumed that
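For reference, the `CustomIntegrator` program above can be traced step by step in plain Python for an unconstrained toy system. This is a sketch, not openmm-torch code: `velocity_verlet_step` and the `force` callable are hypothetical, and both constraint steps are no-ops because the toy system has no constraints (so the `(x-x1)/dt` correction is zero). It makes explicit that forces are needed twice per step, at two different sets of positions.

```python
def velocity_verlet_step(x, v, force, m, dt):
    """Mirror of the CustomIntegrator velocity Verlet program (no constraints)."""
    f = force(x)                                          # forces at the start of the step
    v = [vi + 0.5 * dt * fi / m for vi, fi in zip(v, f)]  # v = v + 0.5*dt*f/m
    x = [xi + dt * vi for xi, vi in zip(x, v)]            # x = x + dt*v
    x1 = list(x)                                          # x1 = x
    # addConstrainPositions(): nothing to constrain, so x is unchanged
    f = force(x)                                          # forces at the UPDATED positions
    v = [vi + 0.5 * dt * fi / m + (xi - x1i) / dt         # v = v + 0.5*dt*f/m + (x-x1)/dt
         for vi, fi, xi, x1i in zip(v, f, x, x1)]
    # addConstrainVelocities(): also a no-op here
    return x, v


calls = []


def harmonic_force(x, k=1.0):
    calls.append(list(x))  # record where forces were evaluated
    return [-k * xi for xi in x]


x, v = [1.0], [0.0]
for _ in range(1000):
    x, v = velocity_verlet_step(x, v, harmonic_force, m=1.0, dt=0.01)

energy = 0.5 * v[0] ** 2 + 0.5 * x[0] ** 2
print(len(calls), energy)  # 2000 force evaluations; energy stays close to its initial 0.5
```

In `CustomIntegrator` the second evaluation happens implicitly, because `f` is recomputed whenever the positions have changed; an interface that only receives forces at the start of the step has no way to request it.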
Under the hood, PyTorch constructs a computational graph, which represents the operations and the input-output dependencies between them. It is probably possible to wrap

```python
class Integrator(openmmtorch.Integrator):
    def forward(self, state):
        dt = self.dt
        m = self.m
        state = updateState(state)
        v = getVelocities(state)
        f = getForces(state)
        x = getPositions(state)
        v = v+0.5*dt*f/m
        x1 = x+dt*v
        state = setPositions(state, x1)
        state = constrainPositions(state)
        x = getPositions(state)
        f = getForces(state)
        v = v+0.5*dt*f/m+(x-x1)/dt
        state = setVelocities(state, v)
        state = constrainVelocities(state)
        return state
```
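A minimal sketch of that state-threading style in plain Python, with an immutable `State` so that every operation returns a new value instead of mutating a context. All names are hypothetical snake_case versions of the functions in the comment, and the force field is a toy harmonic potential (k = 1) with no constraints. Because each call consumes the state produced by the previous one, the order of operations is carried by the data dependencies themselves, which is the kind of dependency a computational graph records.

```python
from dataclasses import dataclass, replace
from typing import Tuple


@dataclass(frozen=True)
class State:
    """Hypothetical immutable snapshot of positions and velocities."""
    x: Tuple[float, ...]
    v: Tuple[float, ...]


def get_positions(state): return list(state.x)
def get_velocities(state): return list(state.v)
def set_positions(state, x): return replace(state, x=tuple(x))
def set_velocities(state, v): return replace(state, v=tuple(v))
def get_forces(state): return [-xi for xi in state.x]  # toy harmonic force, k = 1
def update_state(state): return state                   # nothing to update in the toy
def constrain_positions(state): return state            # no constraints in the toy
def constrain_velocities(state): return state


def step(state, dt, m):
    state = update_state(state)
    v = get_velocities(state)
    f = get_forces(state)
    x = get_positions(state)
    v = [vi + 0.5 * dt * fi / m for vi, fi in zip(v, f)]
    x1 = [xi + dt * vi for xi, vi in zip(x, v)]
    state = constrain_positions(set_positions(state, x1))
    x = get_positions(state)
    f = get_forces(state)  # forces at the new positions
    v = [vi + 0.5 * dt * fi / m + (xi - x1i) / dt for vi, fi, xi, x1i in zip(v, f, x, x1)]
    return constrain_velocities(set_velocities(state, v))


s0 = State(x=(1.0,), v=(0.0,))
s1 = step(s0, dt=0.01, m=1.0)
print(s0, s1)  # s0 is unchanged; s1 is a new State
```

The cost of this style is an extra copy per operation; the benefit is that nothing happens "behind the back" of the graph.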
@jchodera, do you have any examples that go beyond the capability of
I wonder if we could also support torch models used as `Integrator`s in OpenMM. Perhaps something like this could work:

Here, we would have to extend TorchScript with custom ops `openmm_compute_forces`, `openmm_compute_potential`, and `openmm_compute_potential_and_forces`, which would wrap the normal OpenMM energy/force computation. Ideally, these C++ functions would know when the forces or potential did not need to be recomputed (because no particles moved), for example when called at the end of one step and again at the beginning of the next.

To use the integrator in a simulation, the user would create a `TorchIntegrator` object that would behave much like a normal `Integrator`:

Edit: It would also be important to enable the integrator to modify global parameters, as well as define its own that can be accessed through the OpenMM API. I'm not quite sure how that would work, however.
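The caching behaviour proposed for the custom ops can be sketched with a dirty flag: any position update invalidates the stored forces, and repeated requests in between reuse them. This is a plain-Python stand-in for what would be C++ code behind ops such as `openmm_compute_forces`; `ForceCache` and its methods are illustrative names, not an existing API.

```python
class ForceCache:
    """Hypothetical force provider that recomputes only after positions change."""

    def __init__(self, force_fn, positions):
        self._force_fn = force_fn
        self._positions = list(positions)
        self._forces = None      # None marks the cache as stale
        self.evaluations = 0     # number of real force computations performed

    def set_positions(self, positions):
        self._positions = list(positions)
        self._forces = None      # particles moved: invalidate cached forces

    def compute_forces(self):
        if self._forces is None:
            self._forces = self._force_fn(self._positions)
            self.evaluations += 1
        return list(self._forces)  # return a copy so callers cannot corrupt the cache


cache = ForceCache(lambda x: [-xi for xi in x], [1.0, 2.0])
cache.compute_forces()           # end of step n: computed
cache.compute_forces()           # start of step n + 1: reused, nothing moved
cache.set_positions([0.5, 2.0])
cache.compute_forces()           # recomputed after the move
print(cache.evaluations)         # 2
```

A dirty flag avoids comparing every coordinate on each call; the trade-off is that it reports a change even when positions are reset to identical values.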