diff --git a/tests/test_step.py b/tests/test_step.py index 972f84a9..ea22e9be 100644 --- a/tests/test_step.py +++ b/tests/test_step.py @@ -1,7 +1,9 @@ """Test step.Step""" +import copy import logging import re +from collections.abc import Sequence from typing import ClassVar import asdf @@ -9,6 +11,7 @@ import stpipe.config_parser as cp from stpipe import cmdline +from stpipe.datamodel import AbstractDataModel from stpipe.pipeline import Pipeline from stpipe.step import Step @@ -411,3 +414,162 @@ def test_log_records(): pipeline.run() assert any(r == "This step has called out a warning." for r in pipeline.log_records) + + +class StepWithModel(Step): + """A step that immediately saves the model it gets passed in""" + + spec = """ + output_ext = string(default='simplestep') + save_results = boolean(default=True) + """ + + def process(self, input_model): + # make a change to ensure step skip is working + # without having to define SimpleDataModel.meta.stepname + if isinstance(input_model, SimpleDataModel): + input_model.stepstatus = "COMPLETED" + elif isinstance(input_model, SimpleContainer): + for model in input_model: + model.stepstatus = "COMPLETED" + return input_model + + +class SimpleDataModel(AbstractDataModel): + """A simple data model""" + + @property + def crds_observatory(self): + return "jwst" + + def get_crds_parameters(self): + return {"test": "none"} + + def save(self, path, dir_path=None, *args, **kwargs): + saveid = getattr(self, "saveid", None) + if saveid is not None: + fname = saveid + "-saved.txt" + with open(fname, "w") as f: + f.write(f"{path}") + return fname + return None + + +def test_save(tmp_cwd): + + model = SimpleDataModel() + model.saveid = "test" + step = StepWithModel() + step.run(model) + assert (tmp_cwd / "test-saved.txt").exists() + + +def test_skip(): + model = SimpleDataModel() + step = StepWithModel() + step.skip = True + out = step.run(model) + assert not hasattr(out, "stepstatus") + assert out is model + + +@pytest.fixture(scope="function") +def model_list(): + model = SimpleDataModel() + model_list = [copy.deepcopy(model) for _ in range(3)] + for i, model in enumerate(model_list): + model.saveid = f"test{i}" + return model_list + + +def test_save_list(tmp_cwd, model_list): + step = StepWithModel() + step.run(model_list) + for i in range(3): + assert (tmp_cwd / f"test{i}-saved.txt").exists() + + +class SimpleContainer(Sequence): + + def __init__(self, models): + self._models = models + + def __len__(self): + return len(self._models) + + def __getitem__(self, idx): + return self._models[idx] + + def __iter__(self): + yield from self._models + + def insert(self, index, model): + self._models.insert(index, model) + + def append(self, model): + self._models.append(model) + + def extend(self, model): + self._models.extend(model) + + def pop(self, index=-1): + self._models.pop(index) + + +class SimpleContainerWithSave(SimpleContainer): + + def save(self, path, dir_path=None, *args, **kwargs): + for model in self._models[1:]: + # skip the first model to test that the save method is called + # rather than just looping over all models like in the without-save case + model.save(path, dir_path, *args, **kwargs) + + +@pytest.mark.xfail( + reason="Looping over models only works for list and tuple. This should be fixed." +) +def test_save_container(tmp_cwd, model_list): + """ensure list-like save still works for non-list sequence""" + container = SimpleContainer(model_list) + step = StepWithModel() + step.run(container) + for i in range(3): + assert (tmp_cwd / f"test{i}-saved.txt").exists() + + +def test_skip_container(tmp_cwd, model_list): + step = StepWithModel() + step.skip = True + out = step.run(model_list) + assert not hasattr(out, "stepstatus") + for i, model in enumerate(out): + assert not hasattr(model, "stepstatus") + assert model_list[i] is model + + +def test_save_container_with_save_method(tmp_cwd, model_list): + """ensure top-level save method is called for sequence""" + container = SimpleContainerWithSave(model_list) + step = StepWithModel() + step.run(container) + assert not (tmp_cwd / "test0-saved.txt").exists() + assert (tmp_cwd / "test1-saved.txt").exists() + assert (tmp_cwd / "test2-saved.txt").exists() + + +def test_save_tuple_with_nested_list(tmp_cwd, model_list): + """ + in rare cases, multiple outputs are returned from step as tuple. + One example is the jwst badpix_selfcal step, which returns one sci exposure + and a list containing an arbitrary number of background exposures. + Expected behavior in this case, at least at time of writing, is to save the + science exposure and ignore the list + """ + single_model = SimpleDataModel() + single_model.saveid = "test" + container = (single_model, model_list) + step = StepWithModel() + step.run(container) + assert (tmp_cwd / "test-saved.txt").exists() + for i in range(3): + assert not (tmp_cwd / f"test{i}-saved.txt").exists()