From 1d4bf9ab4db2398ad7a9febf7ea44cc9933c9dfc Mon Sep 17 00:00:00 2001 From: pattonw Date: Thu, 13 Jun 2024 18:00:48 -0700 Subject: [PATCH] PyTorch Train: add tests for using arg indexes for model inputs --- tests/cases/torch_train.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/cases/torch_train.py b/tests/cases/torch_train.py index 8fb9e8ec..c66b56f2 100644 --- a/tests/cases/torch_train.py +++ b/tests/cases/torch_train.py @@ -86,7 +86,8 @@ def forward(self, a, b): ), ], ) -def test_loss_drops(tmpdir, device): +@pytest.mark.parametrize("input_args", [True, False]) +def test_loss_drops(tmpdir, device, input_args): checkpoint_basename = str(tmpdir / "model") a_key = ArrayKey("A") @@ -104,7 +105,7 @@ def test_loss_drops(tmpdir, device): model=model, optimizer=optimizer, loss=loss, - inputs={"a": a_key, "b": b_key}, + inputs={"a": a_key, "b": b_key} if not input_args else {0: a_key, 1: b_key}, loss_inputs={0: c_predicted_key, 1: c_key}, outputs={0: c_predicted_key}, gradients={0: c_gradient_key}, @@ -167,7 +168,8 @@ def test_loss_drops(tmpdir, device): ), ], ) -def test_output(device): +@pytest.mark.parametrize("input_args", [True, False]) +def test_spawn_subprocess(device, input_args): logging.getLogger("gunpowder.torch.nodes.predict").setLevel(logging.INFO) a_key = ArrayKey("A") @@ -181,7 +183,7 @@ def test_output(device): source = example_train_source(a_key, b_key, c_key) predict = Predict( model=model, - inputs={"a": a_key, "b": b_key}, + inputs={"a": a_key, "b": b_key} if not input_args else {0: a_key, 1: b_key}, outputs={"linear": c_pred, 0: d_pred}, array_specs={ c_key: ArraySpec(nonspatial=True),