From 30888d4565905c09e1d153077d5b44cf725c36ed Mon Sep 17 00:00:00 2001 From: Ned Molter Date: Tue, 8 Oct 2024 11:09:58 -0400 Subject: [PATCH] added simple tests of step save methods --- src/stpipe/step.py | 2 +- tests/test_step.py | 123 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 124 insertions(+), 1 deletion(-) diff --git a/src/stpipe/step.py b/src/stpipe/step.py index 778eeb7d..683cf55b 100644 --- a/src/stpipe/step.py +++ b/src/stpipe/step.py @@ -990,7 +990,7 @@ def save_model( elif isinstance(model, Sequence) and not isinstance(model, str): if not hasattr(model, "save"): - # list of datamodels, e.g. ModelContainer + # list of datamodels, e.g. JWST ModelContainer output_paths = [] for i, m in enumerate(model): # ignore list of lists. individual steps should handle this diff --git a/tests/test_step.py b/tests/test_step.py index 972f84a9..ea38a0b4 100644 --- a/tests/test_step.py +++ b/tests/test_step.py @@ -1,7 +1,9 @@ """Test step.Step""" +import copy import logging import re +from collections.abc import Sequence from typing import ClassVar import asdf @@ -9,6 +11,7 @@ import stpipe.config_parser as cp from stpipe import cmdline +from stpipe.datamodel import AbstractDataModel from stpipe.pipeline import Pipeline from stpipe.step import Step @@ -411,3 +414,123 @@ def test_log_records(): pipeline.run() assert any(r == "This step has called out a warning." for r in pipeline.log_records) + + +class StepWithModel(Step): + """A step that immediately saves the model it gets passed in""" + + spec = """ + output_ext = string(default='simplestep') + save_results = boolean(default=True) + """ + # spec = """ + # + # skip = bool(default=False) + # """ + + def process(self, input_model): + return input_model + + +class SimpleDataModel(AbstractDataModel): + """A simple data model""" + + @property + def crds_observatory(self): + return "jwst" + + # @property + # def meta(self): + # return {"filename": "test.fits"} + + def get_crds_parameters(self): + return {"test": "none"} + + def save(self, path, dir_path=None, *args, **kwargs): + saveid = getattr(self, "saveid", None) + if saveid is not None: + print(f"here {saveid}") + fname = saveid+"-saved.txt" + with open(fname, "w") as f: + f.write(f"{path}") + return fname + return None + + +def test_save(tmp_cwd): + + model = SimpleDataModel() + model.saveid = "test" + step = StepWithModel() + step.run(model) + assert (tmp_cwd / "test-saved.txt").exists() + + +@pytest.fixture(scope="function") +def model_list(): + model = SimpleDataModel() + model_list = [copy.deepcopy(model) for _ in range(3)] + for i, model in enumerate(model_list): + model.saveid = f"test{i}" + return model_list + + +def test_save_list(tmp_cwd, model_list): + step = StepWithModel() + step.run(model_list) + for i in range(3): + assert (tmp_cwd / f"test{i}-saved.txt").exists() + + +class SimpleContainer(Sequence): + + def __init__(self, models): + self._models = models + + def __len__(self): + return len(self._models) + + def __getitem__(self, idx): + return self._models[idx] + + def __iter__(self): + yield from self._models + + def insert(self, index, model): + self._models.insert(index, model) + + def append(self, model): + self._models.append(model) + + def extend(self, model): + self._models.extend(model) + + def pop(self, index=-1): + self._models.pop(index) + + +def test_save_container(tmp_cwd, model_list): + """ensure list-like save still works for non-list sequence""" + container = SimpleContainer(model_list) + step = StepWithModel() + step.run(container) + for i in range(3): + assert (tmp_cwd / f"test{i}-saved.txt").exists() + + +def test_save_tuple_with_nested_list(tmp_cwd, model_list): + """ + in rare cases, multiple outputs are returned from step as tuple. + One example is the jwst badpix_selfcal step, which returns one sci exposure + and a list containing an arbitrary number of background exposures. + Expected behavior in this case, at least at time of writing, is to save the + science exposure and ignore the list + """ + single_model = SimpleDataModel() + single_model.saveid = "test" + container = (single_model, model_list) + step = StepWithModel() + step.run(container) + assert (tmp_cwd / "test-saved.txt").exists() + for i in range(3): + assert not (tmp_cwd / f"test{i}-saved.txt").exists()