Skip to content

Commit

Permalink
Enable the testing of CUDA Graph
Browse files Browse the repository at this point in the history
  • Loading branch information
Raimondas Galvelis committed Feb 25, 2022
1 parent 11ead20 commit 6c22c88
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion python/tests/TestTorchForce.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@

@pytest.mark.parametrize('platform', ['Reference', 'CUDA'])
@pytest.mark.parametrize('precision', ['single', 'mixed', 'double'])
def testEnergyForces(platform, precision):
@pytest.mark.parametrize('useGraph', [True, False])
def testEnergyForces(platform, precision, useGraph):

if pt.cuda.device_count() < 1 and platform == 'CUDA':
pytest.skip('A CUDA device is not available')

# Create a system
numParticles = 10
Expand All @@ -19,6 +23,8 @@ def testEnergyForces(platform, precision):

# Create a TorchForce
force = ot.TorchForce("../../tests/central.pt")
if useGraph:
force.setPlatformProperty('CUDAGraph', 'true')
system.addForce(force)

# Setup a simulation
Expand Down

0 comments on commit 6c22c88

Please sign in to comment.