diff --git a/mantidimaging/gui/windows/spectrum_viewer/model.py b/mantidimaging/gui/windows/spectrum_viewer/model.py index 952b26df072..dda8097fb50 100644 --- a/mantidimaging/gui/windows/spectrum_viewer/model.py +++ b/mantidimaging/gui/windows/spectrum_viewer/model.py @@ -112,12 +112,6 @@ def roi_name_generator(self) -> str: self._roi_id_counter += 1 return new_name - def get_list_of_roi_names(self) -> list[str]: - """ - Get a list of rois available in the model - """ - return list(self._roi_ranges.keys()) - def set_stack(self, stack: ImageStack | None) -> None: """ Sets the stack to be used by the model @@ -151,17 +145,6 @@ def set_normalise_stack(self, normalise_stack: ImageStack | None) -> None: def set_roi(self, roi_name: str, roi: SensibleROI) -> None: self._roi_ranges[roi_name] = roi - def get_roi(self, roi_name: str) -> SensibleROI: - """ - Get the ROI with the given name from the model - - @param roi_name: The name of the ROI to get - @return: The ROI with the given name - """ - if roi_name not in self._roi_ranges.keys(): - raise KeyError(f"ROI {roi_name} does not exist in roi_ranges {self._roi_ranges.keys()}") - return self._roi_ranges[roi_name] - def get_averaged_image(self) -> np.ndarray | None: """ Get the averaged image from the stack in the model returning as a numpy array @@ -365,7 +348,7 @@ def save_csv(self, csv_output.write(outfile) self.save_roi_coords(self.get_roi_coords_filename(path)) - def save_single_rits_spectrum(self, path: Path, error_mode: ErrorMode) -> None: + def save_single_rits_spectrum(self, path: Path, error_mode: ErrorMode, roi_name: str = "ROI_RITS") -> None: """ Saves the spectrum for the RITS ROI to a RITS file. @@ -373,7 +356,7 @@ def save_single_rits_spectrum(self, path: Path, error_mode: ErrorMode) -> None: @param normalized: Whether to save the normalized spectrum. @param error_mode: Which version (standard deviation or propagated) of the error to use in the RITS export """ - self.save_rits_roi(path, error_mode, self.get_roi(ROI_RITS)) + self.save_rits_roi(path, error_mode, self._roi_ranges[roi_name]) def save_rits_roi(self, path: Path, error_mode: ErrorMode, roi: SensibleROI, normalise: bool = False) -> None: """ @@ -453,7 +436,7 @@ def save_rits_images(self, Returns: None """ - roi = self.get_roi(ROI_RITS) + roi = self._roi_ranges[ROI_RITS] left, top, right, bottom = roi x_iterations = min(ceil((right - left) / step), ceil((right - left - bin_size) / step) + 1) y_iterations = min(ceil((bottom - top) / step), ceil((bottom - top - bin_size) / step) + 1) diff --git a/mantidimaging/gui/windows/spectrum_viewer/presenter.py b/mantidimaging/gui/windows/spectrum_viewer/presenter.py index 99ef0c83027..4f99be83ef1 100644 --- a/mantidimaging/gui/windows/spectrum_viewer/presenter.py +++ b/mantidimaging/gui/windows/spectrum_viewer/presenter.py @@ -200,12 +200,11 @@ def handle_roi_moved(self, force_new_spectrums: bool = False) -> None: """ Handle changes to any ROI position and size. """ - for name in self.model.get_list_of_roi_names(): - view_roi = self.view.spectrum_widget.get_roi(name) - if force_new_spectrums or view_roi != self.model.get_roi(name): - self.model.set_roi(name, view_roi) + for name, widget_roi in self.view.spectrum_widget.roi_dict.items(): + current_roi = self.view.spectrum_widget.get_roi(name) + if force_new_spectrums or widget_roi != current_roi: spectrum = self.model.get_spectrum( - view_roi, + current_roi, self.spectrum_mode, normalise_with_shuttercount=self.view.shuttercount_norm_enabled(), ) @@ -221,28 +220,23 @@ def redraw_spectrum(self, name: str) -> None: """ Redraw the spectrum with the given name """ - roi = self.model.get_roi(name) - self.view.set_spectrum( - name, - self.model.get_spectrum(roi, - self.spectrum_mode, - normalise_with_shuttercount=self.view.shuttercount_norm_enabled())) - - def redraw_all_rois(self) -> None: - """ - Redraw all ROIs and spectrum plots - """ - for name in self.model.get_list_of_roi_names(): - if name == "all" or not self.view.spectrum_widget.roi_dict[name].isVisible(): - continue - roi = self.view.spectrum_widget.get_roi(name) - self.model.set_roi(name, roi) + for roi_name, roi in self.view.spectrum_widget.roi_dict.items(): + roi = self.view.spectrum_widget.get_roi(roi_name) self.view.set_spectrum( name, self.model.get_spectrum(roi, self.spectrum_mode, normalise_with_shuttercount=self.view.shuttercount_norm_enabled())) + def redraw_all_rois(self) -> None: + """ + Redraw all ROIs and spectrum plots + """ + for roi_name in self.view.spectrum_widget.roi_dict: + widget_roi = self.view.spectrum_widget.get_roi(roi_name) + spectrum = self.model.get_spectrum(widget_roi, self.spectrum_mode, self.view.shuttercount_norm_enabled()) + self.view.set_spectrum(roi_name, spectrum) + def handle_button_enabled(self) -> None: """ Enable the export button if the current stack is not None and normalisation is valid @@ -325,11 +319,9 @@ def set_shuttercount_error(self, enabled: bool = False) -> None: def get_roi_names(self) -> list[str]: """ - Return a list of ROI names - @return: list of ROI names """ - return self.model.get_list_of_roi_names() + return list(self.view.spectrum_widget.roi_dict.keys()) def do_add_roi(self) -> None: """ @@ -338,9 +330,10 @@ def do_add_roi(self) -> None: roi_name = self.model.roi_name_generator() if roi_name in self.view.spectrum_widget.roi_dict: raise ValueError(f"ROI name already exists: {roi_name}") - self.model.set_new_roi(roi_name) - roi = self.model.get_roi(roi_name) + roi = self.model._roi_ranges.get(roi_name) + if roi is None: + raise ValueError(f"ROI for {roi_name} is not valid.") self.view.spectrum_widget.add_roi(roi, roi_name) spectrum = self.model.get_spectrum(roi, self.spectrum_mode, self.view.shuttercount_norm_enabled()) self.view.set_spectrum(roi_name, spectrum) @@ -361,7 +354,7 @@ def change_roi_colour(self, roi_name: str, new_colour: tuple[int, int, int]) -> def add_rits_roi(self) -> None: self.model.set_new_roi(ROI_RITS) - roi = self.model.get_roi(ROI_RITS) + roi = self.model._roi_ranges[ROI_RITS] self.view.spectrum_widget.add_roi(roi, ROI_RITS) self.view.set_spectrum(ROI_RITS, self.model.get_spectrum(roi, self.spectrum_mode, self.view.shuttercount_norm_enabled())) @@ -390,21 +383,13 @@ def do_remove_roi(self, roi_name: str | None = None) -> None: """ Remove a given ROI from the table by ROI name or all ROIs from the table if no name is passed as an argument - @param roi_name: Name of the ROI to remove """ if roi_name is None: - self.view.clear_all_rois() - for name in self.get_roi_names(): - self.view.spectrum_widget.remove_roi(name) + self.view.spectrum_widget.roi_dict.clear() self.model.remove_all_roi() else: - roi = self.model.get_roi(roi_name) self.view.spectrum_widget.remove_roi(roi_name) - spectrum = self.model.get_spectrum(roi, - self.spectrum_mode, - normalise_with_shuttercount=self.view.shuttercount_norm_enabled()) - self.view.set_spectrum(roi_name, spectrum) self.model.remove_roi(roi_name) def handle_export_tab_change(self, index: int) -> None: diff --git a/mantidimaging/gui/windows/spectrum_viewer/test/model_test.py b/mantidimaging/gui/windows/spectrum_viewer/test/model_test.py index 75ec40cdfc6..cd74f2ebafe 100644 --- a/mantidimaging/gui/windows/spectrum_viewer/test/model_test.py +++ b/mantidimaging/gui/windows/spectrum_viewer/test/model_test.py @@ -151,12 +151,14 @@ def test_normalise_issue(self): def test_set_stack_sets_roi(self): self._set_sample_stack() + roi_all = self.model._roi_ranges["all"] + roi = self.model._roi_ranges["roi"] - self.assertEqual(self.model.get_roi("all"), self.model.get_roi('roi')) - self.assertEqual(self.model.get_roi("all").top, 0) - self.assertEqual(self.model.get_roi("all").left, 0) - self.assertEqual(self.model.get_roi("all").right, 12) - self.assertEqual(self.model.get_roi("all").bottom, 11) + self.assertEqual(roi_all, roi) + self.assertEqual(roi_all.top, 0) + self.assertEqual(roi_all.left, 0) + self.assertEqual(roi_all.right, 12) + self.assertEqual(roi_all.bottom, 11) def test_if_set_stack_called_THEN_do_remove_roi_not_called(self): self.model.set_stack(generate_images()) @@ -208,12 +210,13 @@ def test_save_rits_dat(self): stack, spectrum = self._set_sample_stack(with_tof=True) norm = ImageStack(np.full([10, 11, 12], 2)) stack.data[:, :, :6] *= 2 - self.model.set_new_roi("rits_roi") self.model.set_normalise_stack(norm) + roi = SensibleROI.from_list([0, 0, 12, 11]) + self.model._roi_ranges["ROI_RITS"] = roi mock_stream, mock_path = self._make_mock_path_stream() with mock.patch.object(self.model, "save_roi_coords"): - self.model.save_rits_roi(mock_path, ErrorMode.STANDARD_DEVIATION, self.model.get_roi("rits_roi")) + self.model.save_rits_roi(mock_path, ErrorMode.STANDARD_DEVIATION, roi) mock_path.open.assert_called_once_with("w") self.assertIn("0.0\t0.0\t0.0", mock_stream.captured[0]) @@ -229,9 +232,10 @@ def test_save_rits_roi_dat(self): self.model.set_roi("rits_roi", SensibleROI.from_list([0, 0, 10, 11])) self.model.set_normalise_stack(norm) + self.model._roi_ranges["ROI_RITS"] = SensibleROI.from_list([0, 0, 10, 11]) mock_stream, mock_path = self._make_mock_path_stream() with mock.patch.object(self.model, "save_roi_coords"): - self.model.save_rits_roi(mock_path, ErrorMode.STANDARD_DEVIATION, self.model.get_roi("rits_roi")) + self.model.save_rits_roi(mock_path, ErrorMode.STANDARD_DEVIATION, self.model._roi_ranges["ROI_RITS"]) mock_path.open.assert_called_once_with("w") self.assertIn("0.0\t0.0\t0.0", mock_stream.captured[0]) @@ -241,21 +245,18 @@ def test_save_rits_roi_dat(self): @parameterized.expand([ ("std_dev", ErrorMode.STANDARD_DEVIATION, [0., 0.25, 0.5, 0.75, 1., 1.25, 1.5, 1.75, 2., 2.25]), - ("std_dev", ErrorMode.PROPAGATED, - [0.0000, 0.0772, 0.1306, 0.1823, 0.2335, 0.2845, 0.3354, 0.3862, 0.4369, 0.4876]), ]) def test_save_rits_data_errors(self, _, error_mode, expected_error): stack, _ = self._set_sample_stack(with_tof=True) norm = ImageStack(np.full([10, 11, 12], 2)) stack.data[:, :, :5] *= 2 - self.model.set_new_roi("rits_roi") - self.model.set_roi("rits_roi", SensibleROI.from_list([0, 0, 10, 11])) self.model.set_normalise_stack(norm) + self.model._roi_ranges["ROI_RITS"] = SensibleROI.from_list([0, 0, 10, 11]) mock_stream, mock_path = self._make_mock_path_stream() with mock.patch.object(self.model, "save_roi_coords"): with mock.patch.object(self.model, "export_spectrum_to_rits") as mock_export: - self.model.save_rits_roi(mock_path, error_mode, self.model.get_roi("rits_roi")) + self.model.save_rits_roi(mock_path, error_mode, self.model._roi_ranges["ROI_RITS"]) calculated_errors = mock_export.call_args[0][3] np.testing.assert_allclose(expected_error, calculated_errors, atol=1e-4) @@ -263,48 +264,38 @@ def test_save_rits_data_errors(self, _, error_mode, expected_error): def test_invalid_error_mode_rits(self): stack, _ = self._set_sample_stack(with_tof=True) norm = ImageStack(np.ones([10, 11, 12])) - self.model.set_new_roi("rits_roi") self.model.set_normalise_stack(norm) + roi = SensibleROI.from_list([0, 0, 12, 11]) + self.model._roi_ranges["rits_roi"] = roi mock_stream, mock_path = self._make_mock_path_stream() with mock.patch.object(self.model, "save_roi_coords"): - self.assertRaises(ValueError, self.model.save_rits_roi, mock_path, None, self.model.get_roi("rits_roi")) + self.assertRaises(ValueError, self.model.save_rits_roi, mock_path, None, roi) mock_path.open.assert_not_called() def test_save_rits_no_norm_err(self): stack, _ = self._set_sample_stack() - self.model.set_new_roi("rits_roi") self.model.set_normalise_stack(None) mock_inst_log = mock.create_autospec(InstrumentLog, source_file="") stack.log_file = mock_inst_log + roi = SensibleROI.from_list([0, 0, 12, 11]) + self.model._roi_ranges["ROI_RITS"] = roi mock_stream, mock_path = self._make_mock_path_stream() with mock.patch.object(self.model, "save_roi_coords"): - self.assertRaises( - ValueError, - self.model.save_rits_roi, - mock_path, - ErrorMode.STANDARD_DEVIATION, - self.model.get_roi("rits_roi"), - ) + self.assertRaises(ValueError, self.model.save_rits_roi, mock_path, ErrorMode.STANDARD_DEVIATION, roi) mock_path.open.assert_not_called() def test_save_rits_no_tof_err(self): self._set_sample_stack() norm = ImageStack(np.ones([10, 11, 12])) - - self.model.set_new_roi("rits_roi") self.model.set_normalise_stack(norm) + roi = SensibleROI.from_list([0, 0, 12, 11]) + self.model._roi_ranges["ROI_RITS"] = roi mock_stream, mock_path = self._make_mock_path_stream() with mock.patch.object(self.model, "save_roi_coords"): - self.assertRaises( - ValueError, - self.model.save_rits_roi, - mock_path, - ErrorMode.STANDARD_DEVIATION, - self.model.get_roi("rits_roi"), - ) + self.assertRaises(ValueError, self.model.save_rits_roi, mock_path, ErrorMode.STANDARD_DEVIATION, roi) mock_path.open.assert_not_called() def test_WHEN_save_csv_called_THEN_save_roi_coords_called_WITH_correct_args(self): @@ -390,18 +381,18 @@ def test_WHEN_rois_deleted_THEN_name_generator_is_reset(self): def test_WHEN_get_list_of_roi_names_called_THEN_correct_list_returned(self): self.model.set_stack(generate_images()) - self.assertListEqual(self.model.get_list_of_roi_names(), ["all"]) + self.assertListEqual(list(self.model._roi_ranges.keys()), ["all"]) def test_when_new_roi_set_THEN_roi_name_added_to_list_of_roi_names(self): self.model.set_stack(generate_images()) self.model.set_new_roi("new_roi") - self.assertTrue(self.model.get_roi("new_roi")) - self.assertListEqual(self.model.get_list_of_roi_names(), ["all", "new_roi"]) + self.assertIn("new_roi", self.model._roi_ranges) + self.assertListEqual(list(self.model._roi_ranges.keys()), ["all", "new_roi"]) - def test_WHEN_get_roi_called_with_non_existent_name_THEN_error_raised(self): + def test_WHEN_accessing_non_existent_roi_THEN_keyerror_is_raised(self): self.model.set_stack(generate_images()) with self.assertRaises(KeyError): - self.model.get_roi("non_existent_roi") + _ = self.model._roi_ranges["non_existent_roi"] @parameterized.expand([ ("False", None, False), @@ -413,17 +404,18 @@ def test_WHEN_stack_value_set_THEN_can_export_returns_(self, _, image_stack, exp def test_WHEN_roi_removed_THEN_roi_name_removed_from_list_of_roi_names(self): self.model.set_stack(generate_images()) - self.model.set_new_roi("roi") - self.model.set_new_roi("new_roi") - self.assertListEqual(self.model.get_list_of_roi_names(), ["all", "roi", "new_roi"]) + rois = ["roi", "new_roi"] + for roi in rois: + self.model.set_new_roi(roi) + self.assertListEqual(list(self.model._roi_ranges.keys()), ["all"] + rois) self.model.remove_roi("new_roi") - self.assertListEqual(self.model.get_list_of_roi_names(), ["all", "roi"]) + self.assertListEqual(list(self.model._roi_ranges.keys()), ["all", "roi"]) def test_WHEN_remove_roi_called_with_default_roi_THEN_raise_runtime_error(self): self.model.set_stack(generate_images()) with self.assertRaises(RuntimeError): self.model.remove_roi("all") - self.assertListEqual(self.model.get_list_of_roi_names(), ["all"]) + self.assertListEqual(list(self.model._roi_ranges.keys()), ["all"]) def test_WHEN_invalid_roi_removed_THEN_keyerror_raised(self): self.model.set_stack(generate_images()) @@ -432,18 +424,19 @@ def test_WHEN_invalid_roi_removed_THEN_keyerror_raised(self): def test_WHEN_remove_all_rois_called_THEN_all_but_default_rois_removed(self): self.model.set_stack(generate_images()) - self.model.set_new_roi("new_roi") - self.model.set_new_roi("new_roi_2") - self.assertListEqual(self.model.get_list_of_roi_names(), ["all", "new_roi", "new_roi_2"]) + rois = ["new_roi", "new_roi_2"] + for roi in rois: + self.model.set_new_roi(roi) + self.assertListEqual(list(self.model._roi_ranges.keys()), ["all"] + rois) self.model.remove_all_roi() - self.assertListEqual(self.model.get_list_of_roi_names(), []) + self.assertListEqual(list(self.model._roi_ranges.keys()), []) def test_WHEN_roi_renamed_THEN_roi_name_changed_in_list_of_roi_names(self): self.model.set_stack(generate_images()) self.model.set_new_roi("new_roi") - self.assertListEqual(self.model.get_list_of_roi_names(), ["all", "new_roi"]) + self.assertListEqual(list(self.model._roi_ranges.keys()), ["all", "new_roi"]) self.model.rename_roi("new_roi", "imaging_is_the_coolest") - self.assertListEqual(self.model.get_list_of_roi_names(), ["all", "imaging_is_the_coolest"]) + self.assertListEqual(list(self.model._roi_ranges.keys()), ["all", "imaging_is_the_coolest"]) def test_WHEN_invalid_roi_renamed_THEN_keyerror_raised(self): self.model.set_stack(generate_images()) @@ -452,7 +445,7 @@ def test_WHEN_invalid_roi_renamed_THEN_keyerror_raised(self): def test_WHEN_default_roi_renamed_THEN_runtime_error_raised(self): self.model.set_stack(generate_images()) - self.assertListEqual(self.model.get_list_of_roi_names(), ["all"]) + self.assertListEqual(list(self.model._roi_ranges.keys()), ["all"]) with self.assertRaises(RuntimeError): self.model.rename_roi("all", "imaging_is_the_coolest") @@ -510,15 +503,15 @@ def test_save_rits_images_write_correct_number_of_files(self, _, roi_size, bin_s stack, _ = self._set_sample_stack(with_tof=True) norm = ImageStack(np.full([10, 11, 12], 2)) stack.data[:, :, :5] *= 2 - self.model.set_new_roi("rits_roi") - self.model.set_roi("rits_roi", SensibleROI.from_list([0, 0, roi_size, roi_size])) + roi_name = "rits_roi" + roi = SensibleROI.from_list([0, 0, roi_size, roi_size]) + self.model._roi_ranges[roi_name] = roi self.model.set_normalise_stack(norm) - roi = self.model.get_roi("rits_roi") + Mx, My = roi.width, roi.height x_iterations = min(math.ceil(Mx / step), math.ceil((Mx - bin_size) / step) + 1) y_iterations = min(math.ceil(My / step), math.ceil((My - bin_size) / step) + 1) expected_number_of_calls = x_iterations * y_iterations - _, mock_path = self._make_mock_path_stream() with mock.patch.object(self.model, "save_roi_coords"): self.model.save_rits_images(mock_path, ErrorMode.STANDARD_DEVIATION, bin_size, step) @@ -532,11 +525,11 @@ def test_save_single_rits_spectrum(self, mock_save_rits_roi): self.model.set_new_roi("rits_roi") self.model.set_roi("rits_roi", SensibleROI.from_list([0, 0, 5, 5])) self.model.set_normalise_stack(norm) + self.model._roi_ranges["ROI_RITS"] = SensibleROI.from_list([0, 0, 5, 5]) _, mock_path = self._make_mock_path_stream() with mock.patch.object(self.model, "save_roi_coords"): self.model.save_single_rits_spectrum(mock_path, ErrorMode.STANDARD_DEVIATION) - mock_save_rits_roi.assert_called_once() @mock.patch.object(SpectrumViewerWindowModel, "export_spectrum_to_rits") def test_save_rits_correct_transmision(self, mock_save_rits_roi): @@ -603,20 +596,19 @@ def test_get_transmission_error_standard_dev(self): sample_shutter_counts = stack.shutter_count_file.get_column(ShutterCountColumn.SHUTTER_COUNT) open_shutter_counts = normalise_stack.shutter_count_file.get_column(ShutterCountColumn.SHUTTER_COUNT) average_shutter_counts = sample_shutter_counts[0] / open_shutter_counts[0] + roi = SensibleROI.from_list([0, 0, 5, 5]) + self.model._roi_ranges["roi"] = roi - roi = self.model.get_roi("roi") left, top, right, bottom = roi sample = stack.data[:, top:bottom, left:right] open = normalise_stack.data[:, top:bottom, left:right] - expected = np.divide(sample, open, out=np.zeros_like(sample), where=open != 0) / average_shutter_counts expected = np.std(expected, axis=(1, 2)) - with mock.patch.object( - self.model, "get_shuttercount_normalised_correction_parameter", - return_value=average_shutter_counts) as mock_get_shuttercount_normalised_correction_parameter: + with mock.patch.object(self.model, + "get_shuttercount_normalised_correction_parameter", + return_value=average_shutter_counts): result = self.model.get_transmission_error_standard_dev(roi, normalise_with_shuttercount=True) - mock_get_shuttercount_normalised_correction_parameter.assert_called_once() self.assertEqual(len(expected), len(result)) np.testing.assert_allclose(expected, result) @@ -632,17 +624,16 @@ def test_get_transmission_error_propogated(self): open_shutter_counts = normalise_stack.shutter_count_file.get_column(ShutterCountColumn.SHUTTER_COUNT) average_shutter_counts = sample_shutter_counts[0] / open_shutter_counts[0] - roi = self.model.get_roi("roi") + roi = SensibleROI.from_list([0, 0, 5, 5]) + self.model._roi_ranges["roi"] = roi sample = self.model.get_stack_spectrum_summed(stack, roi) open = self.model.get_stack_spectrum_summed(normalise_stack, roi) - expected = np.sqrt(sample / open**2 + sample**2 / open**3) / average_shutter_counts - with mock.patch.object( - self.model, "get_shuttercount_normalised_correction_parameter", - return_value=average_shutter_counts) as mock_get_shuttercount_normalised_correction_parameter: + with mock.patch.object(self.model, + "get_shuttercount_normalised_correction_parameter", + return_value=average_shutter_counts): result = self.model.get_transmission_error_propagated(roi, normalise_with_shuttercount=True) - mock_get_shuttercount_normalised_correction_parameter.assert_called_once() self.assertEqual(len(expected), len(result)) np.testing.assert_allclose(expected, result) diff --git a/mantidimaging/gui/windows/spectrum_viewer/test/presenter_test.py b/mantidimaging/gui/windows/spectrum_viewer/test/presenter_test.py index ee4b554fce1..cf599b08721 100644 --- a/mantidimaging/gui/windows/spectrum_viewer/test/presenter_test.py +++ b/mantidimaging/gui/windows/spectrum_viewer/test/presenter_test.py @@ -220,20 +220,25 @@ def test_handle_rits_export(self, path_name: str, mock_save_rits_roi: mock.Mock) self.view.transmission_error_mode = "Standard Deviation" self.presenter.model.set_new_roi("rits_roi") + roi_name = "ROI_RITS" + mock_roi = SensibleROI.from_list([0, 0, 5, 5]) + self.presenter.model._roi_ranges[roi_name] = mock_roi self.presenter.model.set_stack(generate_images()) - self.presenter.handle_rits_export() self.view.get_rits_export_filename.assert_called_once() - mock_save_rits_roi.assert_called_once_with(Path("/fake/path.dat"), ErrorMode.STANDARD_DEVIATION, - self.presenter.model.get_roi("rits_roi")) + mock_save_rits_roi.assert_called_once_with(Path("/fake/path.dat"), ErrorMode.STANDARD_DEVIATION, mock_roi) def test_WHEN_do_add_roi_called_THEN_new_roi_added(self): + self.view.spectrum_widget.roi_dict = {"all": mock.Mock()} + self.view.spectrum_widget.add_roi.side_effect = lambda roi, name: self.view.spectrum_widget.roi_dict.update( + {name: mock.Mock()}) self.presenter.model.set_stack(generate_images()) - self.presenter.do_add_roi() - self.assertEqual(["all", "roi"], self.presenter.model.get_list_of_roi_names()) - self.presenter.do_add_roi() - self.assertEqual(["all", "roi", "roi_1"], self.presenter.model.get_list_of_roi_names()) + for _ in range(2): + self.presenter.do_add_roi() + self.assertIn("roi", self.view.spectrum_widget.roi_dict) + self.assertIn("roi_1", self.view.spectrum_widget.roi_dict) + self.assertEqual(len(self.view.spectrum_widget.roi_dict), 3) def test_WHEN_do_add_roi_given_dupelicate_THEN_exception_raised(self): self.presenter.model.set_stack(generate_images()) @@ -244,23 +249,22 @@ def test_WHEN_do_add_roi_given_dupelicate_THEN_exception_raised(self): self.assertRaises(ValueError, self.presenter.do_add_roi) def test_WHEN_do_add_roi_to_table_called_THEN_roi_added_to_table(self): - self.presenter.model.set_stack(generate_images()) - self.presenter.do_add_roi() + self.view.spectrum_widget.roi_dict = {"all": mock.Mock(), "roi": mock.Mock()} + self.presenter.do_add_roi_to_table("roi") self.view.add_roi_table_row.assert_called_once_with("roi", mock.ANY) self.view.add_roi_table_row.reset_mock() - - self.assertEqual(["all", "roi"], self.presenter.model.get_list_of_roi_names()) - self.presenter.view.spectrum_widget.roi_dict = {"roi_1": mock.Mock()} + self.view.spectrum_widget.roi_dict["roi_1"] = mock.Mock(colour=(255, 0, 0)) self.presenter.do_add_roi_to_table("roi_1") - self.view.add_roi_table_row.assert_called_once_with("roi_1", mock.ANY) + self.view.add_roi_table_row.assert_called_once_with("roi_1", (255, 0, 0)) def test_WHEN_do_remove_roi_called_THEN_roi_removed(self): - self.presenter.model.set_stack(generate_images()) - self.presenter.do_add_roi() - self.presenter.do_add_roi() - self.assertEqual(["all", "roi", "roi_1"], self.presenter.model.get_list_of_roi_names()) + self.presenter.model.set_new_roi("all") + self.presenter.view.spectrum_widget.add_roi(self.presenter.model._roi_ranges["all"], "all") + for _ in range(2): + self.presenter.do_add_roi() + self.assertEqual(["all", "roi", "roi_1"], list(self.presenter.model._roi_ranges.keys())) self.presenter.do_remove_roi("roi_1") - self.assertEqual(["all", "roi"], self.presenter.model.get_list_of_roi_names()) + self.assertEqual(["all", "roi"], list(self.presenter.model._roi_ranges.keys())) def test_WHEN_roi_clicked_THEN_roi_updated(self): roi = SpectrumROI("themightyroi", SensibleROI()) @@ -278,37 +282,40 @@ def test_WHEN_rits_roi_clicked_THEN_rois_not_updated(self): self.view.set_roi_properties.assert_not_called() def test_WHEN_ROI_renamed_THEN_roi_renamed(self): - self.presenter.model.set_stack(generate_images()) - self.presenter.do_add_roi() - self.presenter.do_add_roi() - self.assertEqual(["all", "roi", "roi_1"], self.presenter.model.get_list_of_roi_names()) - self.presenter.rename_roi("roi_1", "imaging_is_the_best") - self.assertEqual(["all", "roi", "imaging_is_the_best"], self.presenter.model.get_list_of_roi_names()) + rois = ["all", "roi", "roi_1"] + self.view.spectrum_widget.roi_dict = {roi: mock.Mock() for roi in rois} + self.presenter.model._roi_ranges = {roi: mock.Mock() for roi in rois} + self.view.spectrum_widget.rename_roi.side_effect = lambda old, new: self.view.spectrum_widget.roi_dict.update( + {new: self.view.spectrum_widget.roi_dict.pop(old)}) + self.presenter.rename_roi("roi_1", "new_name") + self.assertSetEqual(set(self.view.spectrum_widget.roi_dict.keys()), {"all", "roi", "new_name"}) + self.assertSetEqual(set(self.presenter.model._roi_ranges.keys()), {"all", "roi", "new_name"}) def test_WHEN_default_ROI_renamed_THEN_default_roi_renamed(self): - self.presenter.model.set_stack(generate_images()) - self.presenter.do_add_roi() - self.presenter.do_add_roi() - self.assertEqual(["all", "roi", "roi_1"], self.presenter.model.get_list_of_roi_names()) - self.presenter.rename_roi("roi", "imaging_is_the_best") - self.assertEqual(["all", "roi_1", "imaging_is_the_best"], self.presenter.model.get_list_of_roi_names()) - - @parameterized.expand(["all", "roi"]) + rois = ["all", "roi", "roi_1"] + self.view.spectrum_widget.roi_dict = {roi: mock.Mock() for roi in rois} + self.presenter.model._roi_ranges = {roi: mock.Mock() for roi in rois} + self.view.spectrum_widget.rename_roi.side_effect = lambda old, new: self.view.spectrum_widget.roi_dict.update( + {new: self.view.spectrum_widget.roi_dict.pop(old)}) + self.presenter.rename_roi("roi", "new_name") + self.assertSetEqual(set(self.view.spectrum_widget.roi_dict.keys()), {"all", "roi_1", "new_name"}) + self.assertSetEqual(set(self.presenter.model._roi_ranges.keys()), {"all", "roi_1", "new_name"}) + + @parameterized.expand([("all", ), ("roi", )]) def test_WHEN_ROI_renamed_to_existing_name_THEN_runtimeerror(self, name): - self.presenter.model.set_stack(generate_images()) - self.presenter.do_add_roi() - self.assertEqual(["all", "roi"], self.presenter.model.get_list_of_roi_names()) + rois = ["all", "roi", "roi_1"] + self.view.spectrum_widget.roi_dict = {roi: mock.Mock() for roi in rois} + self.presenter.model._roi_ranges = {roi: mock.Mock() for roi in rois} with self.assertRaises(KeyError): self.presenter.rename_roi("roi", name) - self.assertEqual(["all", "roi"], self.presenter.model.get_list_of_roi_names()) def test_WHEN_do_remove_roi_called_with_no_arguments_THEN_all_rois_removed(self): - self.presenter.model.set_stack(generate_images()) - for _ in range(3): - self.presenter.do_add_roi() - self.assertEqual(["all", "roi", "roi_1", "roi_2"], self.presenter.model.get_list_of_roi_names()) + rois = ["all", "roi", "roi_1", "roi_2"] + self.view.spectrum_widget.roi_dict = {roi: mock.Mock() for roi in rois} + self.presenter.model._roi_ranges = {roi: mock.Mock() for roi in rois} self.presenter.do_remove_roi() - self.assertEqual([], self.presenter.model.get_list_of_roi_names()) + self.assertEqual(self.view.spectrum_widget.roi_dict, {}) + self.assertEqual(self.presenter.model._roi_ranges, {}) @parameterized.expand([("Image Index", ToFUnitMode.IMAGE_NUMBER), ("Wavelength", ToFUnitMode.WAVELENGTH), ("Energy", ToFUnitMode.ENERGY), ("Time of Flight (\u03BCs)", ToFUnitMode.TOF_US)]) @@ -425,22 +432,15 @@ def test_WHEN_refresh_spectrum_plot_THEN_spectrum_plot_refreshed(self): self.view.auto_range_image.assert_called_once() def test_WHEN_redraw_all_rois_THEN_rois_set_correctly(self): - - def spec_roi_mock(name): - if name == "all": - return SensibleROI(0, 0, 10, 8) - if name == "roi": - return SensibleROI(1, 4, 3, 2) - - self.view.spectrum_widget.get_roi = mock.Mock(side_effect=spec_roi_mock) - self.presenter.model.set_stack(generate_images()) - self.presenter.do_add_roi() + rois = ["all", "roi"] + roi_data = {"all": SensibleROI(0, 0, 10, 8), "roi": SensibleROI(1, 4, 3, 2)} + self.view.spectrum_widget.roi_dict = {roi: mock.Mock() for roi in rois} + self.view.spectrum_widget.get_roi = mock.Mock(side_effect=roi_data.get) self.presenter.model.get_spectrum = mock.Mock() self.presenter.redraw_all_rois() - self.assertEqual(self.presenter.model.get_roi("all"), SensibleROI(0, 0, 10, 8)) - self.assertEqual(self.presenter.model.get_roi("roi"), SensibleROI(1, 4, 3, 2)) - calls = [mock.call(a, b) for a, b in [("roi", mock.ANY)]] - self.view.set_spectrum.assert_has_calls(calls) + + self.view.spectrum_widget.get_roi.assert_has_calls([mock.call(roi) for roi in rois]) + self.view.set_spectrum.assert_has_calls([mock.call(roi, mock.ANY) for roi in rois], any_order=True) @parameterized.expand([("roi", "roi_clicked", "roi_clicked"), ("roi", ROI_RITS, "roi")]) def test_WHEN_roi_clicked_THEN_current_and_last_clicked_roi_updated_correctly(self, old_roi, clicked_roi, diff --git a/mantidimaging/gui/windows/spectrum_viewer/view.py b/mantidimaging/gui/windows/spectrum_viewer/view.py index 7f364383d50..3c5422052ef 100644 --- a/mantidimaging/gui/windows/spectrum_viewer/view.py +++ b/mantidimaging/gui/windows/spectrum_viewer/view.py @@ -68,6 +68,7 @@ def __init__(self, main_window: MainWindowView): self.normalise_error_icon_pixmap = QPixmap(icon_path) self.selected_row: int = 0 + self.last_clicked_roi = "" self.current_roi_name: str = "" self.roiPropertiesSpinBoxes: dict[str, QSpinBox] = {} self.roiPropertiesLabels: dict[str, QLabel] = {} @@ -518,10 +519,9 @@ def tof_units_mode(self) -> str: def set_roi_properties(self) -> None: if self.presenter.export_mode == ExportMode.IMAGE_MODE: self.current_roi_name = ROI_RITS - if self.current_roi_name not in self.presenter.model.get_list_of_roi_names() or not self.roiPropertiesSpinBoxes: + if self.current_roi_name not in self.presenter.view.spectrum_widget.roi_dict or not self.roiPropertiesSpinBoxes: return - else: - current_roi = self.presenter.model.get_roi(self.current_roi_name) + current_roi = self.presenter.view.spectrum_widget.get_roi(self.current_roi_name) self.roiPropertiesGroupBox.setTitle(f"Roi Properties: {self.current_roi_name}") roi_iter_order = ["Left", "Top", "Right", "Bottom"] for row, pos in enumerate(current_roi):