diff --git a/mantidimaging/core/data/dataset.py b/mantidimaging/core/data/dataset.py index 993c78a4426..db7cf98d272 100644 --- a/mantidimaging/core/data/dataset.py +++ b/mantidimaging/core/data/dataset.py @@ -66,21 +66,38 @@ def sinograms(self, sino: ImageStack | None) -> None: @property def all(self) -> list[ImageStack]: - return self.recons.stacks + self._stacks + remove_nones([self._sinograms]) + named_stacks = [ + self.sample, self.proj180deg, self.flat_before, self.flat_after, self.dark_before, self.dark_after, + self.sinograms + ] + return self.recons.stacks + self._stacks + remove_nones(named_stacks) def delete_stack(self, images_id: uuid.UUID) -> None: - for recon in self.recons: - if recon.id == images_id: - self.recons.remove(recon) - return - for image in self._stacks: - if image.id == images_id: - self._stacks.remove(image) - return - if self.sinograms is not None and self.sinograms.id == images_id: + if isinstance(self.sample, ImageStack) and self.sample.id == images_id: + self.sample = None + elif isinstance(self.flat_before, ImageStack) and self.flat_before.id == images_id: + self.flat_before = None + elif isinstance(self.flat_after, ImageStack) and self.flat_after.id == images_id: + self.flat_after = None + elif isinstance(self.dark_before, ImageStack) and self.dark_before.id == images_id: + self.dark_before = None + elif isinstance(self.dark_after, ImageStack) and self.dark_after.id == images_id: + self.dark_after = None + elif isinstance(self.proj180deg, ImageStack) and self.proj180deg.id == images_id: + assert self.sample is not None + self.sample.clear_proj180deg() + elif isinstance(self.sinograms, ImageStack) and self.sinograms.id == images_id: self.sinograms = None - return - raise KeyError(_delete_stack_error_message(images_id)) + else: + for recon in self.recons: + if recon.id == images_id: + self.recons.remove(recon) + return + for image in self._stacks: + if image.id == images_id: + self._stacks.remove(image) + return + raise KeyError(_delete_stack_error_message(images_id)) def __contains__(self, images_id: uuid.UUID) -> bool: return any(image.id == images_id for image in self.all) @@ -102,20 +119,18 @@ def delete_recons(self) -> None: def add_stack(self, stack: ImageStack) -> None: self._stacks.append(stack) - -class MixedDataset(BaseDataset): - pass - - -class StrictDataset(BaseDataset): - @property - def all(self) -> list[ImageStack]: - image_stacks = [ - self.sample, self.proj180deg, self.flat_before, self.flat_after, self.dark_before, self.dark_after, - self.sinograms - ] - return remove_nones(image_stacks) + self.recons.stacks + def proj180deg(self) -> ImageStack | None: + if self.sample is not None: + return self.sample.proj180deg + else: + return None + + @proj180deg.setter + def proj180deg(self, proj180deg: ImageStack | None) -> None: + if self.sample is None: + raise RuntimeError("Can't set a 180 projection without a sample") + self.sample.proj180deg = proj180deg @property def _nexus_stack_order(self) -> list[ImageStack]: @@ -150,42 +165,6 @@ def image_keys(self) -> list[int]: image_keys += _image_key_list(2, self.dark_after.data.shape[0]) return image_keys - @property - def proj180deg(self) -> ImageStack | None: - if self.sample is not None: - return self.sample.proj180deg - else: - return None - - @proj180deg.setter - def proj180deg(self, proj180deg: ImageStack | None) -> None: - if self.sample is None: - raise RuntimeError("Can't set a 180 projection without a sample") - self.sample.proj180deg = proj180deg - - def delete_stack(self, images_id: uuid.UUID) -> None: - if isinstance(self.sample, ImageStack) and self.sample.id == images_id: - self.sample = None # type: ignore - elif isinstance(self.flat_before, ImageStack) and self.flat_before.id == images_id: - self.flat_before = None - elif isinstance(self.flat_after, ImageStack) and self.flat_after.id == images_id: - self.flat_after = None - elif isinstance(self.dark_before, ImageStack) and self.dark_before.id == images_id: - self.dark_before = None - elif isinstance(self.dark_after, ImageStack) and self.dark_after.id == images_id: - self.dark_after = None - elif isinstance(self.proj180deg, ImageStack) and self.proj180deg.id == images_id: - assert self.sample is not None - self.sample.clear_proj180deg() - elif isinstance(self.sinograms, ImageStack) and self.sinograms.id == images_id: - self.sinograms = None - elif images_id in self.recons.ids: - for recon in self.recons: - if recon.id == images_id: - self.recons.remove(recon) - else: - raise KeyError(_delete_stack_error_message(images_id)) - def set_stack(self, file_type: FILE_TYPES, image_stack: ImageStack) -> None: attr_name = file_type.fname.lower().replace(" ", "_") if file_type == FILE_TYPES.PROJ_180: @@ -206,6 +185,14 @@ def is_processed(self) -> bool: return False +class MixedDataset(BaseDataset): + pass + + +class StrictDataset(BaseDataset): + pass + + def _get_stack_data_type(stack_id: uuid.UUID, dataset: BaseDataset) -> str: """ Find the data type as a string of a stack. diff --git a/mantidimaging/core/data/test/dataset_test.py b/mantidimaging/core/data/test/dataset_test.py index a6af226dfb1..b0da8cf58ce 100644 --- a/mantidimaging/core/data/test/dataset_test.py +++ b/mantidimaging/core/data/test/dataset_test.py @@ -6,10 +6,38 @@ from unittest import mock import uuid +import numpy as np + +from mantidimaging.core.data import ImageStack from mantidimaging.core.data.dataset import BaseDataset, _get_stack_data_type +from mantidimaging.core.utility.data_containers import ProjectionAngles, FILE_TYPES from mantidimaging.test_helpers.unit_test_helper import generate_images +def _make_standard_dataset(shape=(2, 5, 5)): + #image_ids = [mock.create_autospec(uuid.UUID) for _ in range(6)] + image_stacks = [generate_images(shape) for _ in range(6)] + image_stacks[0].name = "samplename" + + ds = BaseDataset(sample=image_stacks[0], + flat_before=image_stacks[1], + flat_after=image_stacks[2], + dark_before=image_stacks[3], + dark_after=image_stacks[4]) + ds.proj180deg = image_stacks[5] + return ds, image_stacks + + +def _set_fake_projection_angles(image_stack: ImageStack): + """ + Sets the private projection angles attribute. + :param image_stack: The ImageStack object. + """ + image_stack._projection_angles = ProjectionAngles(np.array([0, 180])) + #image_stack.real_projection_angles.return_value = image_stack._projection_angles + #image_stack.projection_angles.return_value = image_stack._projection_angles + + class DatasetTest(unittest.TestCase): def test_create_dataset(self): @@ -70,6 +98,17 @@ def test_stacks_in_all(self): ds = BaseDataset(stacks=image_stacks) self.assertListEqual(ds.all, image_stacks) + def test_sample_in_all(self): + image_sample = mock.Mock(proj180deg=None) + ds = BaseDataset(sample=image_sample) + self.assertCountEqual(ds.all, [image_sample]) + + def test_all_for_full_dataset(self): + ds, image_stacks = _make_standard_dataset() + self.assertEqual(len(ds.all), len(image_stacks)) + for image in image_stacks: + self.assertIn(image, ds.all) + def test_delete_stack_from_stacks_list(self): image_stacks = [mock.Mock() for _ in range(3)] ds = BaseDataset(stacks=image_stacks) @@ -89,3 +128,163 @@ def test_get_stack_data_type_returns_images(self): images_id = images.id dataset = BaseDataset(stacks=[images]) self.assertEqual(_get_stack_data_type(images_id, dataset), "Images") + + def test_attribute_not_set_returns_none(self): + sample = mock.Mock() + dataset = BaseDataset(sample=sample) + + self.assertIsNone(dataset.flat_before) + self.assertIsNone(dataset.flat_after) + self.assertIsNone(dataset.dark_before) + self.assertIsNone(dataset.dark_after) + + def test_set_flat_before(self): + sample = mock.Mock() + dataset = BaseDataset(sample=sample) + flat_before = mock.Mock(id="1234") + dataset.flat_before = flat_before + self.assertIs(flat_before, dataset.flat_before) + self.assertIn("1234", dataset) + + def test_all_images_ids(self): + ds, images = _make_standard_dataset() + self.assertCountEqual(ds.all_image_ids, [image.id for image in images]) + + def test_nexus_stack_order(self): + ds, _ = _make_standard_dataset() + self.assertListEqual(ds._nexus_stack_order, + [ds.dark_before, ds.flat_before, ds.sample, ds.flat_after, ds.dark_after]) + + def test_nexus_arrays(self): + ds, _ = _make_standard_dataset() + self.assertListEqual( + ds.nexus_arrays, + [ds.dark_before.data, ds.flat_before.data, ds.sample.data, ds.flat_after.data, ds.dark_after.data]) + + def test_image_keys(self): + ds, images = _make_standard_dataset() + + self.assertListEqual(ds.image_keys, [2, 2, 1, 1, 0, 0, 1, 1, 2, 2]) + + def test_missing_dark_before_image_keys(self): + ds, images = _make_standard_dataset() + ds.dark_before = None + + self.assertListEqual(ds.image_keys, [1, 1, 0, 0, 1, 1, 2, 2]) + + def test_missing_flat_before_image_keys(self): + ds, images = _make_standard_dataset() + ds.flat_before = None + + self.assertListEqual(ds.image_keys, [2, 2, 0, 0, 1, 1, 2, 2]) + + def test_missing_flat_after_image_keys(self): + ds, images = _make_standard_dataset() + ds.flat_after = None + + self.assertListEqual(ds.image_keys, [2, 2, 1, 1, 0, 0, 2, 2]) + + def test_missing_dark_after_image_keys(self): + ds, images = _make_standard_dataset() + ds.dark_after = None + + self.assertListEqual(ds.image_keys, [2, 2, 1, 1, 0, 0, 1, 1]) + + def test_no_sample_image_keys(self): + ds, images = _make_standard_dataset() + ds.sample = None + with self.assertRaises(RuntimeError): + _ = ds.image_keys + + def test_rotation_angles(self): + ds, images = _make_standard_dataset() + for stack in images: + _set_fake_projection_angles(stack) + assert np.array_equal(ds.nexus_rotation_angles, [ + ds.dark_before.projection_angles().value, + ds.flat_before.projection_angles().value, + ds.sample.projection_angles().value, + ds.flat_after.projection_angles().value, + ds.dark_after.projection_angles().value + ]) + + def test_incomplete_nexus_rotation_angles(self): + ds, _ = _make_standard_dataset() + expected_list = [] + for stack in ds._nexus_stack_order: + expected_list.append(np.zeros(stack.num_images)) + + assert np.array_equal(expected_list, ds.nexus_rotation_angles) + + def test_partially_incomplete_nexus_rotation_angles(self): + ds, _ = _make_standard_dataset() + + _set_fake_projection_angles(ds.dark_before) + _set_fake_projection_angles(ds.flat_before) + _set_fake_projection_angles(ds.dark_after) + expected_list = [ + ds.dark_before.projection_angles().value, + ds.flat_before.projection_angles().value, + np.zeros(ds.sample.num_images), + np.zeros(ds.flat_after.num_images), + ds.dark_after.projection_angles().value + ] + + assert np.array_equal(expected_list, ds.nexus_rotation_angles) + + def test_delete_sample(self): + ds, images = _make_standard_dataset() + ds.delete_stack(images[0].id) + self.assertIsNone(ds.sample) + + def test_delete_flat_before(self): + ds, images = _make_standard_dataset() + ds.delete_stack(images[1].id) + self.assertIsNone(ds.flat_before) + + def test_delete_flat_after(self): + ds, images = _make_standard_dataset() + ds.delete_stack(images[2].id) + self.assertIsNone(ds.flat_after) + + def test_delete_dark_before(self): + ds, images = _make_standard_dataset() + ds.delete_stack(images[3].id) + self.assertIsNone(ds.dark_before) + + def test_delete_dark_after(self): + ds, images = _make_standard_dataset() + ds.delete_stack(images[4].id) + self.assertIsNone(ds.dark_after) + + def test_set_stack_by_type_sample(self): + ds = BaseDataset() + sample = mock.Mock() + ds.set_stack(FILE_TYPES.SAMPLE, sample) + + self.assertEqual(ds.sample, sample) + + def test_set_stack_by_type_flat_before(self): + ds = BaseDataset() + stack = mock.Mock() + ds.set_stack(FILE_TYPES.FLAT_BEFORE, stack) + + self.assertEqual(ds.flat_before, stack) + + def test_set_stack_by_type_180(self): + ds = BaseDataset() + sample = mock.Mock() + stack = mock.Mock() + ds.set_stack(FILE_TYPES.SAMPLE, sample) + ds.set_stack(FILE_TYPES.PROJ_180, stack) + + self.assertEqual(ds.proj180deg, stack) + + def test_processed_is_true(self): + ds = BaseDataset(sample=generate_images()) + ds.sample.record_operation("", "") + self.assertTrue(ds.is_processed) + + def test_processed_is_false(self): + ds = BaseDataset(sample=generate_images()) + self.assertFalse(ds.is_processed)