Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unify datasets part 4 #2310

Merged
merged 8 commits into from
Aug 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 48 additions & 61 deletions mantidimaging/core/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,21 +66,38 @@ def sinograms(self, sino: ImageStack | None) -> None:

@property
def all(self) -> list[ImageStack]:
return self.recons.stacks + self._stacks + remove_nones([self._sinograms])
named_stacks = [
self.sample, self.proj180deg, self.flat_before, self.flat_after, self.dark_before, self.dark_after,
self.sinograms
]
return self.recons.stacks + self._stacks + remove_nones(named_stacks)

def delete_stack(self, images_id: uuid.UUID) -> None:
for recon in self.recons:
if recon.id == images_id:
self.recons.remove(recon)
return
for image in self._stacks:
if image.id == images_id:
self._stacks.remove(image)
return
if self.sinograms is not None and self.sinograms.id == images_id:
if isinstance(self.sample, ImageStack) and self.sample.id == images_id:
self.sample = None
elif isinstance(self.flat_before, ImageStack) and self.flat_before.id == images_id:
self.flat_before = None
elif isinstance(self.flat_after, ImageStack) and self.flat_after.id == images_id:
self.flat_after = None
elif isinstance(self.dark_before, ImageStack) and self.dark_before.id == images_id:
self.dark_before = None
elif isinstance(self.dark_after, ImageStack) and self.dark_after.id == images_id:
self.dark_after = None
elif isinstance(self.proj180deg, ImageStack) and self.proj180deg.id == images_id:
assert self.sample is not None
self.sample.clear_proj180deg()
elif isinstance(self.sinograms, ImageStack) and self.sinograms.id == images_id:
self.sinograms = None
return
raise KeyError(_delete_stack_error_message(images_id))
else:
for recon in self.recons:
if recon.id == images_id:
self.recons.remove(recon)
return
for image in self._stacks:
if image.id == images_id:
self._stacks.remove(image)
return
raise KeyError(_delete_stack_error_message(images_id))

def __contains__(self, images_id: uuid.UUID) -> bool:
return any(image.id == images_id for image in self.all)
Expand All @@ -102,20 +119,18 @@ def delete_recons(self) -> None:
def add_stack(self, stack: ImageStack) -> None:
self._stacks.append(stack)


class MixedDataset(BaseDataset):
pass


class StrictDataset(BaseDataset):

@property
def all(self) -> list[ImageStack]:
image_stacks = [
self.sample, self.proj180deg, self.flat_before, self.flat_after, self.dark_before, self.dark_after,
self.sinograms
]
return remove_nones(image_stacks) + self.recons.stacks
def proj180deg(self) -> ImageStack | None:
if self.sample is not None:
return self.sample.proj180deg
else:
return None

@proj180deg.setter
def proj180deg(self, proj180deg: ImageStack | None) -> None:
if self.sample is None:
raise RuntimeError("Can't set a 180 projection without a sample")
self.sample.proj180deg = proj180deg

@property
def _nexus_stack_order(self) -> list[ImageStack]:
Expand Down Expand Up @@ -150,42 +165,6 @@ def image_keys(self) -> list[int]:
image_keys += _image_key_list(2, self.dark_after.data.shape[0])
return image_keys

@property
def proj180deg(self) -> ImageStack | None:
if self.sample is not None:
return self.sample.proj180deg
else:
return None

@proj180deg.setter
def proj180deg(self, proj180deg: ImageStack | None) -> None:
if self.sample is None:
raise RuntimeError("Can't set a 180 projection without a sample")
self.sample.proj180deg = proj180deg

def delete_stack(self, images_id: uuid.UUID) -> None:
if isinstance(self.sample, ImageStack) and self.sample.id == images_id:
self.sample = None # type: ignore
elif isinstance(self.flat_before, ImageStack) and self.flat_before.id == images_id:
self.flat_before = None
elif isinstance(self.flat_after, ImageStack) and self.flat_after.id == images_id:
self.flat_after = None
elif isinstance(self.dark_before, ImageStack) and self.dark_before.id == images_id:
self.dark_before = None
elif isinstance(self.dark_after, ImageStack) and self.dark_after.id == images_id:
self.dark_after = None
elif isinstance(self.proj180deg, ImageStack) and self.proj180deg.id == images_id:
assert self.sample is not None
self.sample.clear_proj180deg()
elif isinstance(self.sinograms, ImageStack) and self.sinograms.id == images_id:
self.sinograms = None
elif images_id in self.recons.ids:
for recon in self.recons:
if recon.id == images_id:
self.recons.remove(recon)
else:
raise KeyError(_delete_stack_error_message(images_id))

def set_stack(self, file_type: FILE_TYPES, image_stack: ImageStack) -> None:
attr_name = file_type.fname.lower().replace(" ", "_")
if file_type == FILE_TYPES.PROJ_180:
Expand All @@ -206,6 +185,14 @@ def is_processed(self) -> bool:
return False


class MixedDataset(BaseDataset):
pass


class StrictDataset(BaseDataset):
pass


def _get_stack_data_type(stack_id: uuid.UUID, dataset: BaseDataset) -> str:
"""
Find the data type as a string of a stack.
Expand Down
199 changes: 199 additions & 0 deletions mantidimaging/core/data/test/dataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,38 @@
from unittest import mock
import uuid

import numpy as np

from mantidimaging.core.data import ImageStack
from mantidimaging.core.data.dataset import BaseDataset, _get_stack_data_type
from mantidimaging.core.utility.data_containers import ProjectionAngles, FILE_TYPES
from mantidimaging.test_helpers.unit_test_helper import generate_images


def _make_standard_dataset(shape=(2, 5, 5)):
#image_ids = [mock.create_autospec(uuid.UUID) for _ in range(6)]
image_stacks = [generate_images(shape) for _ in range(6)]
image_stacks[0].name = "samplename"

ds = BaseDataset(sample=image_stacks[0],
flat_before=image_stacks[1],
flat_after=image_stacks[2],
dark_before=image_stacks[3],
dark_after=image_stacks[4])
ds.proj180deg = image_stacks[5]
return ds, image_stacks


def _set_fake_projection_angles(image_stack: ImageStack):
"""
Sets the private projection angles attribute.
:param image_stack: The ImageStack object.
"""
image_stack._projection_angles = ProjectionAngles(np.array([0, 180]))
#image_stack.real_projection_angles.return_value = image_stack._projection_angles
#image_stack.projection_angles.return_value = image_stack._projection_angles


class DatasetTest(unittest.TestCase):

def test_create_dataset(self):
Expand Down Expand Up @@ -70,6 +98,17 @@ def test_stacks_in_all(self):
ds = BaseDataset(stacks=image_stacks)
self.assertListEqual(ds.all, image_stacks)

def test_sample_in_all(self):
image_sample = mock.Mock(proj180deg=None)
ds = BaseDataset(sample=image_sample)
self.assertCountEqual(ds.all, [image_sample])

def test_all_for_full_dataset(self):
ds, image_stacks = _make_standard_dataset()
self.assertEqual(len(ds.all), len(image_stacks))
for image in image_stacks:
self.assertIn(image, ds.all)

def test_delete_stack_from_stacks_list(self):
image_stacks = [mock.Mock() for _ in range(3)]
ds = BaseDataset(stacks=image_stacks)
Expand All @@ -89,3 +128,163 @@ def test_get_stack_data_type_returns_images(self):
images_id = images.id
dataset = BaseDataset(stacks=[images])
self.assertEqual(_get_stack_data_type(images_id, dataset), "Images")

def test_attribute_not_set_returns_none(self):
sample = mock.Mock()
dataset = BaseDataset(sample=sample)

self.assertIsNone(dataset.flat_before)
self.assertIsNone(dataset.flat_after)
self.assertIsNone(dataset.dark_before)
self.assertIsNone(dataset.dark_after)

def test_set_flat_before(self):
sample = mock.Mock()
dataset = BaseDataset(sample=sample)
flat_before = mock.Mock(id="1234")
dataset.flat_before = flat_before
self.assertIs(flat_before, dataset.flat_before)
self.assertIn("1234", dataset)

def test_all_images_ids(self):
ds, images = _make_standard_dataset()
self.assertCountEqual(ds.all_image_ids, [image.id for image in images])

def test_nexus_stack_order(self):
ds, _ = _make_standard_dataset()
self.assertListEqual(ds._nexus_stack_order,
[ds.dark_before, ds.flat_before, ds.sample, ds.flat_after, ds.dark_after])

def test_nexus_arrays(self):
ds, _ = _make_standard_dataset()
self.assertListEqual(
ds.nexus_arrays,
[ds.dark_before.data, ds.flat_before.data, ds.sample.data, ds.flat_after.data, ds.dark_after.data])

def test_image_keys(self):
ds, images = _make_standard_dataset()

self.assertListEqual(ds.image_keys, [2, 2, 1, 1, 0, 0, 1, 1, 2, 2])

def test_missing_dark_before_image_keys(self):
ds, images = _make_standard_dataset()
ds.dark_before = None

self.assertListEqual(ds.image_keys, [1, 1, 0, 0, 1, 1, 2, 2])

def test_missing_flat_before_image_keys(self):
ds, images = _make_standard_dataset()
ds.flat_before = None

self.assertListEqual(ds.image_keys, [2, 2, 0, 0, 1, 1, 2, 2])

def test_missing_flat_after_image_keys(self):
ds, images = _make_standard_dataset()
ds.flat_after = None

self.assertListEqual(ds.image_keys, [2, 2, 1, 1, 0, 0, 2, 2])

def test_missing_dark_after_image_keys(self):
ds, images = _make_standard_dataset()
ds.dark_after = None

self.assertListEqual(ds.image_keys, [2, 2, 1, 1, 0, 0, 1, 1])

def test_no_sample_image_keys(self):
ds, images = _make_standard_dataset()
ds.sample = None
with self.assertRaises(RuntimeError):
_ = ds.image_keys

def test_rotation_angles(self):
ds, images = _make_standard_dataset()
for stack in images:
_set_fake_projection_angles(stack)
assert np.array_equal(ds.nexus_rotation_angles, [
ds.dark_before.projection_angles().value,
ds.flat_before.projection_angles().value,
ds.sample.projection_angles().value,
ds.flat_after.projection_angles().value,
ds.dark_after.projection_angles().value
])

def test_incomplete_nexus_rotation_angles(self):
ds, _ = _make_standard_dataset()
expected_list = []
for stack in ds._nexus_stack_order:
expected_list.append(np.zeros(stack.num_images))

assert np.array_equal(expected_list, ds.nexus_rotation_angles)

def test_partially_incomplete_nexus_rotation_angles(self):
ds, _ = _make_standard_dataset()

_set_fake_projection_angles(ds.dark_before)
_set_fake_projection_angles(ds.flat_before)
_set_fake_projection_angles(ds.dark_after)
expected_list = [
ds.dark_before.projection_angles().value,
ds.flat_before.projection_angles().value,
np.zeros(ds.sample.num_images),
np.zeros(ds.flat_after.num_images),
ds.dark_after.projection_angles().value
]

assert np.array_equal(expected_list, ds.nexus_rotation_angles)

def test_delete_sample(self):
ds, images = _make_standard_dataset()
ds.delete_stack(images[0].id)
self.assertIsNone(ds.sample)

def test_delete_flat_before(self):
ds, images = _make_standard_dataset()
ds.delete_stack(images[1].id)
self.assertIsNone(ds.flat_before)

def test_delete_flat_after(self):
ds, images = _make_standard_dataset()
ds.delete_stack(images[2].id)
self.assertIsNone(ds.flat_after)

def test_delete_dark_before(self):
ds, images = _make_standard_dataset()
ds.delete_stack(images[3].id)
self.assertIsNone(ds.dark_before)

def test_delete_dark_after(self):
ds, images = _make_standard_dataset()
ds.delete_stack(images[4].id)
self.assertIsNone(ds.dark_after)

def test_set_stack_by_type_sample(self):
ds = BaseDataset()
sample = mock.Mock()
ds.set_stack(FILE_TYPES.SAMPLE, sample)

self.assertEqual(ds.sample, sample)

def test_set_stack_by_type_flat_before(self):
ds = BaseDataset()
stack = mock.Mock()
ds.set_stack(FILE_TYPES.FLAT_BEFORE, stack)

self.assertEqual(ds.flat_before, stack)

def test_set_stack_by_type_180(self):
ds = BaseDataset()
sample = mock.Mock()
stack = mock.Mock()
ds.set_stack(FILE_TYPES.SAMPLE, sample)
ds.set_stack(FILE_TYPES.PROJ_180, stack)

self.assertEqual(ds.proj180deg, stack)

def test_processed_is_true(self):
ds = BaseDataset(sample=generate_images())
ds.sample.record_operation("", "")
self.assertTrue(ds.is_processed)

def test_processed_is_false(self):
ds = BaseDataset(sample=generate_images())
self.assertFalse(ds.is_processed)