diff --git a/mantidimaging/core/data/dataset.py b/mantidimaging/core/data/dataset.py index bda476f6aa6..5086a551094 100644 --- a/mantidimaging/core/data/dataset.py +++ b/mantidimaging/core/data/dataset.py @@ -73,18 +73,31 @@ def all(self) -> list[ImageStack]: return self.recons.stacks + self._stacks + remove_nones(named_stacks) def delete_stack(self, images_id: uuid.UUID) -> None: - for recon in self.recons: - if recon.id == images_id: - self.recons.remove(recon) - return - for image in self._stacks: - if image.id == images_id: - self._stacks.remove(image) - return - if self.sinograms is not None and self.sinograms.id == images_id: + if isinstance(self.sample, ImageStack) and self.sample.id == images_id: + self.sample = None + elif isinstance(self.flat_before, ImageStack) and self.flat_before.id == images_id: + self.flat_before = None + elif isinstance(self.flat_after, ImageStack) and self.flat_after.id == images_id: + self.flat_after = None + elif isinstance(self.dark_before, ImageStack) and self.dark_before.id == images_id: + self.dark_before = None + elif isinstance(self.dark_after, ImageStack) and self.dark_after.id == images_id: + self.dark_after = None + elif isinstance(self.proj180deg, ImageStack) and self.proj180deg.id == images_id: + assert self.sample is not None + self.sample.clear_proj180deg() + elif isinstance(self.sinograms, ImageStack) and self.sinograms.id == images_id: self.sinograms = None - return - raise KeyError(_delete_stack_error_message(images_id)) + else: + for recon in self.recons: + if recon.id == images_id: + self.recons.remove(recon) + return + for image in self._stacks: + if image.id == images_id: + self._stacks.remove(image) + return + raise KeyError(_delete_stack_error_message(images_id)) def __contains__(self, images_id: uuid.UUID) -> bool: return any(image.id == images_id for image in self.all) @@ -159,29 +172,6 @@ class MixedDataset(BaseDataset): class StrictDataset(BaseDataset): - def delete_stack(self, images_id: uuid.UUID) -> None: - if isinstance(self.sample, ImageStack) and self.sample.id == images_id: - self.sample = None # type: ignore - elif isinstance(self.flat_before, ImageStack) and self.flat_before.id == images_id: - self.flat_before = None - elif isinstance(self.flat_after, ImageStack) and self.flat_after.id == images_id: - self.flat_after = None - elif isinstance(self.dark_before, ImageStack) and self.dark_before.id == images_id: - self.dark_before = None - elif isinstance(self.dark_after, ImageStack) and self.dark_after.id == images_id: - self.dark_after = None - elif isinstance(self.proj180deg, ImageStack) and self.proj180deg.id == images_id: - assert self.sample is not None - self.sample.clear_proj180deg() - elif isinstance(self.sinograms, ImageStack) and self.sinograms.id == images_id: - self.sinograms = None - elif images_id in self.recons.ids: - for recon in self.recons: - if recon.id == images_id: - self.recons.remove(recon) - else: - raise KeyError(_delete_stack_error_message(images_id)) - def set_stack(self, file_type: FILE_TYPES, image_stack: ImageStack) -> None: attr_name = file_type.fname.lower().replace(" ", "_") if file_type == FILE_TYPES.PROJ_180: diff --git a/mantidimaging/core/data/test/dataset_test.py b/mantidimaging/core/data/test/dataset_test.py index 4ba348d8866..f4a31608076 100644 --- a/mantidimaging/core/data/test/dataset_test.py +++ b/mantidimaging/core/data/test/dataset_test.py @@ -231,3 +231,28 @@ def test_partially_incomplete_nexus_rotation_angles(self): ] assert np.array_equal(expected_list, ds.nexus_rotation_angles) + + def test_delete_sample(self): + ds, images = _make_standard_dataset() + ds.delete_stack(images[0].id) + self.assertIsNone(ds.sample) + + def test_delete_flat_before(self): + ds, images = _make_standard_dataset() + ds.delete_stack(images[1].id) + self.assertIsNone(ds.flat_before) + + def test_delete_flat_after(self): + ds, images = _make_standard_dataset() + ds.delete_stack(images[2].id) + self.assertIsNone(ds.flat_after) + + def test_delete_dark_before(self): + ds, images = _make_standard_dataset() + ds.delete_stack(images[3].id) + self.assertIsNone(ds.dark_before) + + def test_delete_dark_after(self): + ds, images = _make_standard_dataset() + ds.delete_stack(images[4].id) + self.assertIsNone(ds.dark_after)