Skip to content

Commit

Permalink
less invasive changes to save logic in step
Browse files Browse the repository at this point in the history
  • Loading branch information
emolter committed Oct 9, 2024
1 parent 843c18c commit 15936e9
Showing 1 changed file with 49 additions and 41 deletions.
90 changes: 49 additions & 41 deletions src/stpipe/step.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,9 +492,11 @@ def run(self, *args):
)
library.shelve(model, i)

elif isinstance(args[0], Sequence) and self.class_alias is not None:
elif (isinstance(args[0], Sequence)) and \
(not isinstance(args[0], str)) and \
(self.class_alias is not None):
# handle ModelContainer or list of models
if isinstance(args[0][0], AbstractDataModel):
if args[0] and isinstance(args[0][0], AbstractDataModel):
for model in args[0]:
try:
setattr(
Expand All @@ -515,8 +517,7 @@ def run(self, *args):
] = "SKIPPED"
except AttributeError as e:
self.log.info(
"Could not record skip into DataModel"
" header: %s",
"Could not record skip into DataModel header: %s",
e,
)
step_result = args[0]
Expand Down Expand Up @@ -560,7 +561,37 @@ def run(self, *args):

# Save the output file if one was specified
if not self.skip and self.save_results:
self.save_model(step_result)
# Setup the save list.
if isinstance(step_result, Sequence):
if hasattr(step_result, "save") or isinstance(step_result, str):
results_to_save = [step_result]
else:
results_to_save = step_result
else:
results_to_save = [step_result]

for idx, result in enumerate(results_to_save):
if len(results_to_save) <= 1:
idx = None
if isinstance(
result, (AbstractDataModel | AbstractModelLibrary)
):
self.save_model(result, idx=idx)
elif hasattr(result, "save"):
try:
output_path = self.make_output_path(idx=idx)
except AttributeError:
self.log.warning(
"`save_results` has been requested, but cannot"
" determine filename."
)
self.log.warning(
"Specify an output file with `--output_file` or set"
" `--save_results=false`"
)
else:
self.log.info("Saving file %s", output_path)
result.save(output_path, overwrite=True)

if not self.skip:
self.log.info("Step %s done", self.name)
Expand Down Expand Up @@ -988,39 +1019,18 @@ def save_model(
model.shelve(m, i)
return output_paths

elif isinstance(model, Sequence) and not isinstance(model, str):
if not hasattr(model, "save"):
# list of datamodels, e.g. JWST ModelContainer
output_paths = []
for i, m in enumerate(model):
# ignore list of lists. individual steps should handle this
if not isinstance(m, Sequence):
idx = None if len(model) == 1 else i
output_paths.append(
self.save_model(
m,
idx=idx,
suffix=suffix,
force=force,
**components,
)
)
return output_paths
else:
# JWST SourceModelContainer takes this path
save_model_func = partial(
self.save_model,
suffix=suffix,
force=force,
**components,
)
output_path = model.save(
path=output_file,
save_model_func=save_model_func,
)
return output_path

elif hasattr(model, "save"):
elif isinstance(model, Sequence):
save_model_func = partial(
self.save_model,
suffix=suffix,
force=force,
**components,
)
output_path = model.save(
path=output_file,
save_model_func=save_model_func,
)
else:
# Search for an output file name.
if self.output_use_model or (
output_file is None and not self.search_output_file
Expand All @@ -1036,10 +1046,8 @@ def save_model(
)
)
self.log.info("Saved model in %s", output_path)
return output_path

else:
return
return output_path

@property
def make_output_path(self):
Expand Down

0 comments on commit 15936e9

Please sign in to comment.