Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

changing operations to compute function #2066

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 14 additions & 13 deletions mantidimaging/core/operations/circular_mask/circular_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
from __future__ import annotations

from functools import partial
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, List, Dict, Any

import numpy as np
import tomopy

from mantidimaging.core.parallel import shared as ps
from mantidimaging.core.operations.base_filter import BaseFilter
from mantidimaging.core.utility.progress_reporting import Progress
from mantidimaging.gui.utility.qt_helpers import Type

if TYPE_CHECKING:
Expand All @@ -33,24 +34,24 @@ def filter_func(data: ImageStack, circular_mask_ratio=0.95, circular_mask_value=
"""
:param data: Input data as a 3D numpy.ndarray
:param circular_mask_ratio: The ratio to the full image.
The ratio must be 0 < ratio < 1
:param circular_mask_value: The value that all pixels in the mask
will be set to.

@@ -39,20 +41,21 @@ def filter_func(data: ImageStack, circular_mask_ratio=0.95, circular_mask_value=
:return: The processed 3D numpy.ndarray
"""
if not circular_mask_ratio or not circular_mask_ratio < 1:
raise ValueError(f'circular_mask_ratio must be > 0 and < 1. Value provided was {circular_mask_ratio}')

progress = Progress.ensure_instance(progress, num_steps=1, task_name='Circular Mask')
if not 0 < circular_mask_ratio < 1:
raise ValueError(
f"Circular mask ratio must be greater than 0 and less than 1, but value was {circular_mask_ratio}")

with progress:
progress.update(msg="Applying circular mask")
params = {'circular_mask_ratio': circular_mask_ratio, 'circular_mask_value': circular_mask_value}

tomopy.circ_mask(arr=data.data, axis=0, ratio=circular_mask_ratio, val=circular_mask_value)
ps.run_compute_func(CircularMaskFilter.compute_function, data.data.shape[0], [data.shared_array], params,
progress)

return data

@staticmethod
def compute_function(i: int, arrays: List[np.ndarray], params: Dict[str, Any]):
tomopy.circ_mask(arrays[0][i], axis=0, ratio=params['circular_mask_ratio'], val=params['circular_mask_value'])

@staticmethod
def register_gui(form, on_change, view):
from mantidimaging.gui.utility import add_property_to_form
Expand Down
45 changes: 26 additions & 19 deletions mantidimaging/core/operations/clip_values/clip_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
from __future__ import annotations

from functools import partial
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, List, Dict, Any

import numpy as np

from mantidimaging.core.operations.base_filter import BaseFilter
from mantidimaging.core.utility.progress_reporting import Progress
from mantidimaging.core.parallel import shared as ps

if TYPE_CHECKING:
from mantidimaging.core.data import ImageStack
Expand All @@ -25,8 +27,9 @@ class ClipValuesFilter(BaseFilter):
filter_name = "Clip Values"
link_histograms = True

@staticmethod
def filter_func(data,
@classmethod
def filter_func(cls,
data,
clip_min=None,
clip_max=None,
clip_min_new_value=None,
Expand Down Expand Up @@ -54,27 +57,31 @@ def filter_func(data,
"""
# We're using is None because 0.0 is a valid value
if clip_min is None and clip_max is None:
raise ValueError('At least one of clip_min or clip_max must be supplied')
raise ValueError("At least one of clip_min or clip_max must be supplied")

progress = Progress.ensure_instance(progress, num_steps=2, task_name='Clipping Values.')
with progress:
sample = data.data
progress.update(msg="Determining clip min and clip max")
clip_min = clip_min if clip_min is not None else sample.min()
clip_max = clip_max if clip_max is not None else sample.max()
params = {
'clip_min': clip_min,
'clip_max': clip_max,
'clip_min_new_value': clip_min_new_value,
'clip_max_new_value': clip_max_new_value
}

clip_min_new_value = clip_min_new_value if clip_min_new_value is not None else clip_min
ps.run_compute_func(cls.compute_function, data.data.shape[0], [data.shared_array], params, progress)

clip_max_new_value = clip_max_new_value if clip_max_new_value is not None else clip_max
return data

progress.update(msg=f"Clipping data with values min {clip_min} and max {clip_max}")
@staticmethod
def compute_function(i: int, arrays: List[np.ndarray], params: Dict[str, Any]):
array = arrays[0][i]

# this is the fastest way to clip the values, np.clip does not do
# the clipping in place and ends up copying the data
sample[sample < clip_min] = clip_min_new_value
sample[sample > clip_max] = clip_max_new_value
clip_min = params.get('clip_min', np.min(array))
clip_max = params.get('clip_max', np.max(array))
clip_min_new_value = params.get('clip_min_new_value', clip_min)
clip_max_new_value = params.get('clip_max_new_value', clip_max)

return data
np.clip(array, clip_min, clip_max, out=array)
array[array < clip_min] = clip_min_new_value
array[array > clip_max] = clip_max_new_value

@staticmethod
def register_gui(form, on_change, view):
Expand Down
2 changes: 1 addition & 1 deletion mantidimaging/core/operations/crop_coords/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@
# SPDX - License - Identifier: GPL-3.0-or-later
from __future__ import annotations

from .crop_coords import CropCoordinatesFilter, execute_single # noqa:F401
from .crop_coords import CropCoordinatesFilter # noqa:F401

FILTER_CLASS = CropCoordinatesFilter
40 changes: 18 additions & 22 deletions mantidimaging/core/operations/crop_coords/crop_coords.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
from __future__ import annotations

from functools import partial
from typing import Union, Optional, List, TYPE_CHECKING
from typing import Union, Optional, List, TYPE_CHECKING, Dict, Any

import numpy as np

from mantidimaging import helper as h
from mantidimaging.core.parallel import utility as pu, shared as ps
from mantidimaging.core.operations.base_filter import BaseFilter, FilterGroup
from mantidimaging.core.parallel import utility as pu
from mantidimaging.core.utility.progress_reporting import Progress
from mantidimaging.core.utility.sensible_roi import SensibleROI
from mantidimaging.gui.utility.qt_helpers import Type

Expand Down Expand Up @@ -69,11 +70,24 @@ def filter_func(images: ImageStack,
"This can happen on the image preview right after a previous Crop Coordinates.")

output = pu.create_array(shape, images.dtype)
execute_single(sample, region_of_interest, progress, out=output.array)
params = {'sample': sample, 'roi': region_of_interest, 'output': output.array}
ps.run_compute_func(CropCoordinatesFilter.compute_function, sample.shape[0], images.shared_array, params,
progress)
images.shared_array = output
return images

@staticmethod
def compute_function(i: int, array: np.ndarray, params: Dict[str, Any]):
_ = array
sample = params['sample']
roi = params['roi']
output = params['output']
if isinstance(roi, SensibleROI):
left, top, right, bottom = roi.left, roi.top, roi.right, roi.bottom
else:
left, top, right, bottom = roi[0], roi[1], roi[2], roi[3]
output[i] = sample[i, top:bottom, left:right]

def register_gui(form, on_change, view):
from mantidimaging.gui.utility import add_property_to_form
label, roi_field = add_property_to_form("ROI",
Expand All @@ -97,21 +111,3 @@ def execute_wrapper(roi_field: QLineEdit) -> partial:
@staticmethod
def group_name() -> FilterGroup:
return FilterGroup.Basic


def execute_single(data, roi, progress=None, out=None):
progress = Progress.ensure_instance(progress, task_name='Crop Coords')

if roi:
progress.add_estimated_steps(1)

with progress:
assert all(isinstance(region, int) for
region in roi), \
"The region of interest coordinates are not integers!"

progress.update(msg="Cropping with coordinates: {0}".format(roi))

output = out[:] if out is not None else data[:]
output[:] = data[:, roi.top:roi.bottom, roi.left:roi.right]
return output
28 changes: 16 additions & 12 deletions mantidimaging/core/operations/divide/divide.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
from functools import partial
from typing import Union, Callable, Dict, Any, TYPE_CHECKING

from mantidimaging import helper as h
import numpy as np

from mantidimaging.core.parallel import shared as ps
from mantidimaging.core.operations.base_filter import BaseFilter
from mantidimaging.gui.utility.qt_helpers import Type

Expand Down Expand Up @@ -36,16 +38,24 @@ def filter_func(images: ImageStack, value: Union[int, float] = 0, unit="micron",

:return: The ImageStack object which has been divided by a value.
"""
h.check_data_stack(images)
if not value:
raise ValueError('value parameter must not equal 0 or None')
if value == 0:
raise ValueError('value parameter must not equal 0')

# Convert microns to cm if necessary
if unit == "micron":
value *= 1e-4
conversion_factor = 1e-4 # Example conversion factor
value *= conversion_factor

params = {'value': value}
ps.run_compute_func(DivideFilter.compute_function, images.data.shape[0], images.shared_array, params, progress)

images.data /= value
return images

@staticmethod
def compute_function(i: int, array: np.ndarray, params: dict):
value = params['value']
array[i] /= value

@staticmethod
def register_gui(form: 'QFormLayout', on_change: Callable, view: 'BasePresenter') -> Dict[str, Any]:
from mantidimaging.gui.utility import add_property_to_form
Expand Down Expand Up @@ -75,9 +85,3 @@ def execute_wrapper( # type: ignore
value = value_widget.value()
unit = unit_widget.currentText()
return partial(DivideFilter.filter_func, value=value, unit=unit)

@staticmethod
def validate_execute_kwargs(kwargs: Dict[str, Any]) -> bool:
if 'value_widget' not in kwargs:
return False
return True
59 changes: 20 additions & 39 deletions mantidimaging/core/operations/nan_removal/nan_removal.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,13 @@
from __future__ import annotations

from functools import partial
from logging import getLogger
from typing import Dict, TYPE_CHECKING

import numpy as np
import scipy.ndimage as scipy_ndimage
from tomopy import median_filter

from mantidimaging.core.operations.base_filter import BaseFilter
from mantidimaging.core.parallel import shared as ps
from mantidimaging.core.utility.progress_reporting import Progress
from mantidimaging.gui.utility.qt_helpers import Type

if TYPE_CHECKING:
Expand Down Expand Up @@ -51,16 +49,28 @@ def filter_func(data, replace_value=None, mode_value="Constant", progress=None)
:return: The ImageStack object with the NaNs replaced.
"""

params = {'replace_value': replace_value, 'mode_value': mode_value}
ps.run_compute_func(NaNRemovalFilter.compute_function, data.data.shape[0], data.shared_array, params, progress)

return data

@staticmethod
def compute_function(i: int, array: np.ndarray, params: dict):
mode_value = params['mode_value']
replace_value = params['replace_value']
if mode_value == "Constant":
sample = data.data
nan_idxs = np.isnan(sample)
sample[nan_idxs] = replace_value
nan_idxs = np.isnan(array[i])
array[i][nan_idxs] = replace_value
elif mode_value == "Median":
_execute(data, 3, "reflect", progress)
nans = np.isnan(array[i])
if np.any(nans):
median_data = np.where(nans, -np.inf, array[i])
median_data = median_filter(median_data, size=3, mode='reflect')
array[i] = np.where(nans, median_data, array[i])
# Convert infs back to NaNs
array[i] = np.where(np.logical_and(nans, array[i] == -np.inf), np.nan, array[i])
else:
raise ValueError(f"Unknown mode: '{mode_value}'\nShould be one of {NaNRemovalFilter.MODES}")

return data
raise ValueError(f"Unknown mode: '{mode_value}'. Should be one of {NaNRemovalFilter.MODES}")

@staticmethod
def register_gui(form: 'QFormLayout', on_change: Callable, view: 'BaseMainWindowView') -> Dict[str, 'QWidget']:
Expand Down Expand Up @@ -92,32 +102,3 @@ def execute_wrapper(mode_field=None, replace_value_field=None):
mode_value = mode_field.currentText()
replace_value = replace_value_field.value()
return partial(NaNRemovalFilter.filter_func, replace_value=replace_value, mode_value=mode_value)


def _nan_to_median(data: np.ndarray, size: int, edgemode: str):
nans = np.isnan(data)
if np.any(nans):
median_data = np.where(nans, -np.inf, data)
median_data = scipy_ndimage.median_filter(median_data, size=size, mode=edgemode)
data = np.where(nans, median_data, data)

if np.any(data == -np.inf):
# Convert any left over -infs back to NaNs
data = np.where(np.logical_and(nans, data == -np.inf), np.nan, data)

return data


def _execute(images: ImageStack, size, edgemode, progress=None):
log = getLogger(__name__)
progress = Progress.ensure_instance(progress, task_name='NaN Removal')

# create the partial function to forward the parameters
f = ps.create_partial(_nan_to_median, ps.return_to_self, size=size, edgemode=edgemode)

with progress:
log.info("PARALLEL NaN Removal filter, with pixel data type: {0}".format(images.dtype))

ps.execute(f, [images.shared_array], images.data.shape[0], progress, msg="NaN Removal")

return images
Loading
Loading