Skip to content

Commit

Permalink
Bug channel annot merge (#275)
Browse files Browse the repository at this point in the history
  • Loading branch information
nmarkowitz authored Aug 18, 2024
1 parent 519c6e9 commit cfa0f60
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 7 deletions.
44 changes: 37 additions & 7 deletions mne_qt_browser/_pg_figure.py
Original file line number Diff line number Diff line change
Expand Up @@ -2359,7 +2359,7 @@ def __init__(self, mne, weakmain, annot, ch_name):

self.mne.plt.addItem(self, ignoreBounds=True)

self.annot.removeRequested.connect(self.remove)
self.annot.removeSingleChannelAnnots.connect(self.remove)
self.annot.sigRegionChangeFinished.connect(self.update_plot_curves)
self.annot.sigRegionChanged.connect(self.update_plot_curves)
self.annot.sigToggleVisibility.connect(self.update_visible)
Expand Down Expand Up @@ -2399,6 +2399,7 @@ class AnnotRegion(LinearRegionItem):
regionChangeFinished = Signal(object)
gotSelected = Signal(object)
removeRequested = Signal(object)
removeSingleChannelAnnots = Signal(object)
sigToggleVisibility = Signal(bool)
sigUpdateColor = Signal(str)

Expand Down Expand Up @@ -2435,29 +2436,54 @@ def __init__(self, mne, description, values, weakmain, ch_names=None):
self.mne.plt.addItem(self.label_item, ignoreBounds=True)

def _region_changed(self):
self.regionChangeFinished.emit(self)
self.old_onset = self.getRegion()[0]
# remove merged regions
# Check for overlapping regions
overlap_has_sca = []
overlapping_regions = list()
for region in self.mne.regions:
if region.description != self.description or id(self) == id(region):
continue
values = region.getRegion()
if any(self.getRegion()[0] <= val <= self.getRegion()[1] for val in values):
if (
any(self.getRegion()[0] <= val <= self.getRegion()[1] for val in values)
or (values[0] <= self.getRegion()[0] <= values[1])
and (values[0] <= self.getRegion()[1] <= values[1])
):
overlapping_regions.append(region)
overlap_has_sca.append(len(region.single_channel_annots) > 0)

# If this region or an overlapping region have
# channel specific annotations then terminate
if (len(self.single_channel_annots) > 0 or any(overlap_has_sca)) and len(
overlapping_regions
) > 0:
dur = self.getRegion()[1] - self.getRegion()[0]
self.setRegion((self.old_onset, self.old_onset + dur))
warn(
"Can not combine channel-based annotations with "
"any other annotation."
)
return

# figure out new boundaries
regions_ = np.array(
[region.getRegion() for region in overlapping_regions] + [self.getRegion()]
)

self.regionChangeFinished.emit(self)

onset = np.min(regions_[:, 0])
offset = np.max(regions_[:, 1])

self.old_onset = onset

logger.debug(f"New {self.description} region: {onset:.2f} - {offset:.2f}")
# remove overlapping regions
for region in overlapping_regions:
self.weakmain()._remove_region(region, from_annot=False)
# re-set while blocking the signal to avoid re-running this function
with SignalBlocker(self):
self.setRegion((onset, offset))

self.update_label_pos()

def _add_single_channel_annot(self, ch_name):
Expand All @@ -2469,7 +2495,7 @@ def _remove_single_channel_annot(self, ch_name):
self.single_channel_annots[ch_name].remove()
self.single_channel_annots.pop(ch_name)

def _toggle_single_channel_annot(self, ch_name):
def _toggle_single_channel_annot(self, ch_name, update_color=True):
"""Add or remove single channel annotations."""
# Exit if mne-python not updated to support shift-click
if not hasattr(self.weakmain(), "_toggle_single_channel_annotation"):
Expand All @@ -2486,7 +2512,10 @@ def _toggle_single_channel_annot(self, ch_name):
else:
self._remove_single_channel_annot(ch_name)

self.update_color(all_channels=(not list(self.single_channel_annots.keys())))
if update_color:
self.update_color(
all_channels=(not list(self.single_channel_annots.keys()))
)

def update_color(self, all_channels=True):
"""Update color of annotation-region.
Expand Down Expand Up @@ -2539,6 +2568,7 @@ def update_visible(self, visible):

def remove(self):
"""Remove annotation-region."""
self.removeSingleChannelAnnots.emit(self)
self.removeRequested.emit(self)
vb = self.mne.viewbox
if vb and self.label_item in vb.addedItems:
Expand Down
32 changes: 32 additions & 0 deletions mne_qt_browser/tests/test_pg_specific.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,28 @@ def test_annotations_interactions(raw_orig, pg_backend):
assert fig.msg_box.informativeText() == "Start can't be bigger or " "equal to Stop!"
fig.msg_box.close()

# Test that dragging annotation onto the tail of another works
annot_dock._remove_description("E")
annot_dock._remove_description("C")
fig._fake_click(
(4.0, 1.0), add_points=[(6.0, 1.0)], xform="data", button=1, kind="drag"
)
fig._fake_click(
(4.0, 1.0), add_points=[(3.0, 1.0)], xform="data", button=1, kind="drag"
)
assert len(raw_orig.annotations.onset) == 1
assert len(fig.mne.regions) == 1

# Make a smaller annotation and put it into the larger one
fig._fake_click(
(8.0, 1.0), add_points=[(8.1, 1.0)], xform="data", button=1, kind="drag"
)
fig._fake_click(
(8.0, 1.0), add_points=[(4.0, 1.0)], xform="data", button=1, kind="drag"
)
assert len(raw_orig.annotations.onset) == 1
assert len(fig.mne.regions) == 1


def test_ch_specific_annot(raw_orig, pg_backend):
"""Test plotting channel specific annotations."""
Expand Down Expand Up @@ -167,6 +189,16 @@ def test_ch_specific_annot(raw_orig, pg_backend):
modifier=Qt.ShiftModifier,
)
assert "MEG 0133" in annot.single_channel_annots.keys()

# Check that channel specific annotations do not merge
fig._fake_click(
(2.0, 1.0), add_points=[(3.0, 1.0)], xform="data", button=1, kind="drag"
)
with pytest.warns(RuntimeWarning, match="combine channel-based"):
fig._fake_click(
(2.1, 1.0), add_points=[(5.0, 1.0)], xform="data", button=1, kind="drag"
)

else:
# emit a warning if the user tries to test single channel annots
with pytest.warns(RuntimeWarning, match="updated"):
Expand Down

0 comments on commit cfa0f60

Please sign in to comment.