Skip to content

Commit

Permalink
try to reduce _mad usage fully (#155)
Browse files Browse the repository at this point in the history
  • Loading branch information
sappelhoff authored Aug 24, 2024
1 parent 6652b8b commit 125d8f7
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 56 deletions.
2 changes: 1 addition & 1 deletion docs/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ Version 0.5.0 (Unreleased)
Changelog
~~~~~~~~~
- :meth:`~pyprep.NoisyChannels.find_bad_by_nan_flat` now accepts a ``flat_threshold`` argument, by `Nabil Alibou`_ (:gh:`144`)
- changed _mad function in utils.py to use median_abs_deviation from the SciPy module, by `Ayush Agarwal`_ (:gh:`153`).
- replaced an internal implementation of the MAD algorithm with :func:`scipy.stats.median_abs_deviation`, by `Ayush Agarwal`_ (:gh:`153`) and `Stefan Appelhoff`_ (:gh:`154`)

Bug
~~~
Expand Down
13 changes: 7 additions & 6 deletions pyprep/find_noisy_channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
import numpy as np
from mne.utils import check_random_state, logger
from scipy import signal
from scipy.stats import median_abs_deviation

from pyprep.ransac import find_bad_by_ransac
from pyprep.removeTrend import removeTrend
from pyprep.utils import _filter_design, _mad, _mat_iqr, _mat_quantile
from pyprep.utils import _filter_design, _mat_iqr, _mat_quantile


class NoisyChannels:
Expand Down Expand Up @@ -247,7 +248,7 @@ def find_bad_by_nan_flat(self, flat_threshold=1e-15):
nan_channels = self.ch_names_original[nan_channel_mask]

# Detect channels with flat or extremely weak signals
flat_by_mad = _mad(EEGData, axis=1) < flat_threshold
flat_by_mad = median_abs_deviation(EEGData, axis=1) < flat_threshold
flat_by_stdev = np.std(EEGData, axis=1) < flat_threshold
flat_channel_mask = flat_by_mad | flat_by_stdev
flat_channels = self.ch_names_original[flat_channel_mask]
Expand Down Expand Up @@ -336,8 +337,8 @@ def find_bad_by_hfnoise(self, HF_zscore_threshold=5.0):
# < 50 Hz amplitude for each channel and get robust z-scores of values
if self.sample_rate > 100:
noisiness = np.divide(
_mad(self.EEGData - self.EEGFiltered, axis=1),
_mad(self.EEGFiltered, axis=1),
median_abs_deviation(self.EEGData - self.EEGFiltered, axis=1),
median_abs_deviation(self.EEGFiltered, axis=1),
)
noise_median = np.nanmedian(noisiness)
noise_sd = np.median(np.abs(noisiness - noise_median)) * MAD_TO_SD
Expand Down Expand Up @@ -421,7 +422,7 @@ def find_bad_by_correlation(
channel_amplitudes[w, usable] = _mat_iqr(eeg_raw, axis=1) * IQR_TO_SD

# Check for any channel dropouts (flat signal) within the window
eeg_amplitude = _mad(eeg_filtered, axis=1)
eeg_amplitude = median_abs_deviation(eeg_filtered, axis=1)
dropout[w, usable] = eeg_amplitude == 0

# Exclude any dropout chans from further calculations (avoids div-by-zero)
Expand All @@ -431,7 +432,7 @@ def find_bad_by_correlation(
eeg_amplitude = eeg_amplitude[eeg_amplitude > 0]

# Get high-frequency noise ratios for the window
high_freq_amplitude = _mad(eeg_raw - eeg_filtered, axis=1)
high_freq_amplitude = median_abs_deviation(eeg_raw - eeg_filtered, axis=1)
noiselevels[w, usable] = high_freq_amplitude / eeg_amplitude

# Get inter-channel correlations for the window
Expand Down
18 changes: 0 additions & 18 deletions pyprep/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
_correlate_arrays,
_eeglab_create_highpass,
_get_random_subset,
_mad,
_mat_iqr,
_mat_quantile,
_mat_round,
Expand Down Expand Up @@ -150,20 +149,3 @@ def test_eeglab_create_highpass():
expected_val = 0.9961
actual_val = vals[len(vals) // 2]
assert np.isclose(expected_val, actual_val, atol=0.001)


def test_mad():
    """Test the median absolute deviation from the median (MAD) function.

    Checks row-wise, column-wise and whole-array results against known
    values, and that inputs with more than two dimensions are rejected.
    """
    # Generate test data
    tst = np.array([[1, 2, 3, 4, 8], [80, 10, 20, 30, 40], [100, 200, 800, 300, 400]])
    expected = np.asarray([1, 10, 100])

    # Compare output to expected results; assert_array_equal also verifies the
    # output shape and prints the mismatching values on failure, which a bare
    # ``all(np.equal(...))`` does not.
    np.testing.assert_array_equal(_mad(tst, axis=1), expected)
    np.testing.assert_array_equal(_mad(tst.T, axis=0), expected)
    assert _mad(tst) == 28  # Matches robust.mad from statsmodels

    # Test exception with > 2-D arrays (only the shape matters here, so use
    # deterministic data instead of unseeded random values)
    tst = np.zeros((3, 3, 3))
    with pytest.raises(ValueError):
        _mad(tst, axis=0)
31 changes: 0 additions & 31 deletions pyprep/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from psutil import virtual_memory
from scipy import linalg
from scipy.signal import firwin, lfilter, lfilter_zi
from scipy.stats import median_abs_deviation


def _union(list1, list2):
Expand Down Expand Up @@ -462,36 +461,6 @@ def _correlate_arrays(a, b, matlab_strict=False):
return np.diag(np.corrcoef(a, b)[:n_chan, n_chan:])


def _mad(x, axis=None):
"""Calculate median absolute deviations from the median (MAD) for an array.
Parameters
----------
x : np.ndarray
A 1-D or 2-D numeric array to summarize.
axis : {int, tuple of int, None}, optional
Axis along which MADs should be calculated. If ``None``, the MAD will
be calculated for the full input array. Defaults to ``None``.
Returns
-------
mad : scalar or np.ndarray
If no axis is specified, returns the MAD for the full input array as a
single numeric value. Otherwise, returns an ``np.ndarray`` containing
the MAD for each index along the specified axis.
"""
# Ensure array is either 1D or 2D
x = np.asarray(x)
if x.ndim > 2:
e = "Only 1D and 2D arrays are supported (input has {0} dimensions)"
raise ValueError(e.format(x.ndim))

# Calculate the median absolute deviation from the median
mad = median_abs_deviation(x, axis=axis)
return mad


def _filter_design(N_order, amp, freq):
"""Create FIR low-pass filter for EEG data using frequency sampling method.
Expand Down

0 comments on commit 125d8f7

Please sign in to comment.