diff --git a/pyprep/find_noisy_channels.py b/pyprep/find_noisy_channels.py index 4dd7e8a..69da184 100644 --- a/pyprep/find_noisy_channels.py +++ b/pyprep/find_noisy_channels.py @@ -54,7 +54,7 @@ def __init__(self, raw, do_detrend=True, random_state=None, matlab_strict=False) raw.load_data() self.raw_mne = raw.copy() - self.raw_mne.pick_types(eeg=True) + self.raw_mne.pick(picks="eeg") self.sample_rate = raw.info["sfreq"] if do_detrend: self.raw_mne._data = removeTrend( diff --git a/pyprep/prep_pipeline.py b/pyprep/prep_pipeline.py index f54a66a..69dbe8d 100644 --- a/pyprep/prep_pipeline.py +++ b/pyprep/prep_pipeline.py @@ -5,7 +5,6 @@ from pyprep.find_noisy_channels import NoisyChannels from pyprep.reference import Reference from pyprep.removeTrend import removeTrend -from pyprep.utils import _set_diff, _union # noqa: F401 class PrepPipeline: diff --git a/pyprep/reference.py b/pyprep/reference.py index 38ae27f..40ebab8 100644 --- a/pyprep/reference.py +++ b/pyprep/reference.py @@ -79,7 +79,7 @@ def __init__( raw.load_data() self.raw = raw.copy() self.ch_names = self.raw.ch_names - self.raw.pick_types(eeg=True, eog=False, meg=False) + self.raw.pick(picks="eeg") self.ch_names_eeg = self.raw.ch_names self.EEG = self.raw.get_data() self.reference_channels = params["ref_chs"] @@ -95,38 +95,25 @@ def __init__( self.matlab_strict = matlab_strict def perform_reference(self, max_iterations=4): - """Estimate the true signal mean and interpolate bad channels. - - Parameters - ---------- - max_iterations : int, optional - The maximum number of iterations of noisy channel removal to perform - during robust referencing. Defaults to ``4``. - - This function implements the functionality of the `performReference` function - as part of the PREP pipeline on mne raw object. - - Notes - ----- - This function calls ``robust_reference`` first. - Currently this function only implements the functionality of default - settings, i.e., ``doRobustPost``. - - """ + """Estimate the true signal mean and interpolate bad channels.""" # Phase 1: Estimate the true signal mean with robust referencing self.robust_reference(max_iterations) - # If we interpolate the raw here we would be interpolating - # more than what we later actually account for (in interpolated channels). + + # Create a copy of raw data to estimate reference signal dummy = self.raw.copy() dummy.info["bads"] = self.noisy_channels["bad_all"] + if self.matlab_strict: _eeglab_interpolate_bads(dummy) else: dummy.interpolate_bads() + + # Calculate the reference signal self.reference_signal = np.nanmean( dummy.get_data(picks=self.reference_channels), axis=0 ) del dummy + rereferenced_index = [ self.ch_names_eeg.index(ch) for ch in self.rereferenced_channels ] @@ -141,18 +128,31 @@ def perform_reference(self, max_iterations=4): ) noisy_detector.find_all_bads(**self.ransac_settings) - # Record Noisy channels and EEG before interpolation + # Record noisy channels and EEG before interpolation self.bad_before_interpolation = noisy_detector.get_bads(verbose=True) self.EEG_before_interpolation = self.EEG.copy() self.noisy_channels_before_interpolation = noisy_detector.get_bads(as_dict=True) self._extra_info["interpolated"] = noisy_detector._extra_info - bad_channels = _union(self.bad_before_interpolation, self.unusable_channels) - self.raw.info["bads"] = bad_channels + # Handle both cases: list or dict + if isinstance(self.bad_before_interpolation, dict): + bad_channels_from_dict = self.bad_before_interpolation.get("bad_all", []) + else: + bad_channels_from_dict = self.bad_before_interpolation + + # Ensure 'bads' is a list of channel names + bad_channels = _union(bad_channels_from_dict, self.unusable_channels) + valid_bad_channels = [ + ch for ch in bad_channels if ch in self.raw.info["ch_names"] + ] + self.raw.info["bads"] = valid_bad_channels + if self.matlab_strict: _eeglab_interpolate_bads(self.raw) else: self.raw.interpolate_bads() + + # Correct the reference signal after interpolation reference_correct = np.nanmean( self.raw.get_data(picks=self.reference_channels), axis=0 ) @@ -160,19 +160,24 @@ def perform_reference(self, max_iterations=4): self.EEG = self.remove_reference( self.EEG, reference_correct, rereferenced_index ) - # reference signal after interpolation + + # Update the reference signal after interpolation self.reference_signal_new = self.reference_signal + reference_correct + # MNE Raw object after interpolation self.raw._data = self.EEG # Still noisy channels after interpolation - self.interpolated_channels = bad_channels + self.interpolated_channels = valid_bad_channels noisy_detector = NoisyChannels( self.raw, random_state=self.random_state, matlab_strict=self.matlab_strict ) noisy_detector.find_all_bads(**self.ransac_settings) self.still_noisy_channels = noisy_detector.get_bads() - self.raw.info["bads"] = self.still_noisy_channels + valid_still_noisy_channels = [ + ch for ch in self.still_noisy_channels if ch in self.raw.info["ch_names"] + ] + self.raw.info["bads"] = valid_still_noisy_channels self.noisy_channels_after_interpolation = noisy_detector.get_bads(as_dict=True) self._extra_info["remaining_bad"] = noisy_detector._extra_info @@ -197,14 +202,14 @@ def robust_reference(self, max_iterations=4): after referencing. reference_signal: np.ndarray, shape(n, ) Estimation of the 'true' signal mean - """ + # Copy and detrend the data raw = self.raw.copy() raw._data = removeTrend( raw.get_data(), self.sfreq, matlab_strict=self.matlab_strict ) - # Determine unusable channels and remove them from the reference channels + # Detect initial noisy channels noisy_detector = NoisyChannels( raw, do_detrend=False, @@ -214,16 +219,16 @@ def robust_reference(self, max_iterations=4): noisy_detector.find_all_bads(**self.ransac_settings) self.noisy_channels_original = noisy_detector.get_bads(as_dict=True) self._extra_info["initial_bad"] = noisy_detector._extra_info - logger.info("Bad channels: {}".format(self.noisy_channels_original)) + logger.info(f"Initial bad channels: {self.noisy_channels_original}") - # Determine channels to use/exclude from initial reference estimation + # Determine channels to exclude from reference estimation self.unusable_channels = _union( noisy_detector.bad_by_nan + noisy_detector.bad_by_flat, noisy_detector.bad_by_SNR, ) reference_channels = _set_diff(self.reference_channels, self.unusable_channels) - # Initialize channels to permanently flag as bad during referencing + # Initialize structure to store noisy channels noisy = { "bad_by_nan": noisy_detector.bad_by_nan, "bad_by_flat": noisy_detector.bad_by_flat, @@ -236,84 +241,75 @@ def robust_reference(self, max_iterations=4): "bad_all": [], } - # Get initial estimate of the reference by the specified method - signal = raw.get_data() + # Get initial reference signal self.reference_signal = np.nanmedian( raw.get_data(picks=reference_channels), axis=0 ) reference_index = [self.ch_names_eeg.index(ch) for ch in reference_channels] signal_tmp = self.remove_reference( - signal, self.reference_signal, reference_index + raw.get_data(), self.reference_signal, reference_index ) - # Remove reference from signal, iteratively interpolating bad channels - raw_tmp = raw.copy() + # Iteratively update the reference signal and noisy channels iterations = 0 previous_bads = set() + raw_tmp = raw.copy() - while True: + while iterations < max_iterations: raw_tmp._data = signal_tmp + + # Detect noisy channels noisy_detector = NoisyChannels( raw_tmp, do_detrend=False, random_state=self.random_state, matlab_strict=self.matlab_strict, ) - # Detrend applied at the beginning of the function. - - # Detect all currently bad channels noisy_detector.find_all_bads(**self.ransac_settings) noisy_new = noisy_detector.get_bads(as_dict=True) - # Specify bad channel types to ignore when updating noisy channels - # NOTE: MATLAB PREP ignores dropout channels, possibly by mistake? - # see: https://github.com/VisLab/EEG-Clean-Tools/issues/28 - ignore = ["bad_by_SNR", "bad_all"] - if self.matlab_strict: - ignore += ["bad_by_dropout"] - - # Update set of all noisy channels detected so far with any new ones + # Update noisy channels, excluding certain types if needed bad_chans = set() - for bad_type in noisy_new.keys(): - noisy[bad_type] = _union(noisy[bad_type], noisy_new[bad_type]) - if bad_type not in ignore: + for bad_type, channels in noisy_new.items(): + noisy[bad_type] = _union(noisy[bad_type], channels) + if bad_type not in {"bad_by_SNR", "bad_all"} or not self.matlab_strict: bad_chans.update(noisy[bad_type]) + noisy["bad_all"] = list(bad_chans) - logger.info("Bad channels: {}".format(noisy)) - - if ( - iterations > 1 - and (len(bad_chans) == 0 or bad_chans == previous_bads) - or iterations > max_iterations - ): - logger.info("Robust reference done") - self.noisy_channels = noisy + logger.info(f"Updated bad channels: {noisy}") + + # Stop if no new bad channels or maximum iterations reached + if bad_chans == previous_bads or len(bad_chans) == 0: + logger.info("Robust reference completed.") break - previous_bads = bad_chans.copy() - if raw_tmp.info["nchan"] - len(bad_chans) < 2: + if len(bad_chans) >= raw_tmp.info["nchan"] - 2: raise ValueError( "RobustReference:TooManyBad " - "Could not perform a robust reference -- not enough good channels" + "Not enough good channels left to perform robust referencing." ) - if len(bad_chans) > 0: - raw_tmp._data = signal.copy() - raw_tmp.info["bads"] = list(bad_chans) - if self.matlab_strict: - _eeglab_interpolate_bads(raw_tmp) - else: - raw_tmp.interpolate_bads() + # Interpolate bad channels + raw_tmp._data = raw.get_data().copy() + raw_tmp.info["bads"] = list(bad_chans) + if self.matlab_strict: + _eeglab_interpolate_bads(raw_tmp) + else: + raw_tmp.interpolate_bads() + # Update the reference signal self.reference_signal = np.nanmean( raw_tmp.get_data(picks=reference_channels), axis=0 ) - signal_tmp = self.remove_reference( - signal, self.reference_signal, reference_index + raw.get_data(), self.reference_signal, reference_index ) - iterations = iterations + 1 - logger.info("Iterations: {}".format(iterations)) + + iterations += 1 + logger.info(f"Iteration {iterations} completed.") + + # Store the final set of noisy channels + self.noisy_channels = noisy return self.noisy_channels, self.reference_signal diff --git a/pyprep/tests/test_matprep_compare.py b/pyprep/tests/test_matprep_compare.py index 739cddf..20948f1 100644 --- a/pyprep/tests/test_matprep_compare.py +++ b/pyprep/tests/test_matprep_compare.py @@ -150,16 +150,18 @@ def pyprep_reference(matprep_artifacts): right before MATLAB PREP calls ``performReference``. As such, the results of these tests will not be affected by any differences in the CleanLine implementations of MATLAB PREP and PyPREP. - """ # Import post-CleanLine MATLAB PREP data setfile_path = matprep_artifacts["3_matprep_cleanline"] matprep_set = mne.io.read_raw_eeglab(setfile_path, preload=True) ch_names = matprep_set.info["ch_names"] + # Ensure that ch_names is a 1D list or array + ch_names = np.array(ch_names, dtype=str) + # Run robust referencing on MATLAB data and extract internal noisy info matprep_seed = 435656 - params = {"ref_chs": ch_names, "reref_chs": ch_names} + params = {"ref_chs": ch_names.tolist(), "reref_chs": ch_names.tolist()} pyprep_reref = Reference( matprep_set, params, random_state=matprep_seed, matlab_strict=True ) @@ -386,8 +388,8 @@ def test_full_signal(self, pyprep_reference, matprep_reference): win_size = 500 # window of samples to check # Compare signals at start of recording - pyprep_start = pyprep_reference.raw.get_data()[:, win_size] - matprep_start = matprep_reference.get_data()[:, win_size] + pyprep_start = pyprep_reference.raw.get_data()[:, :win_size] + matprep_start = matprep_reference.get_data()[:, :win_size] assert np.allclose(pyprep_start, matprep_start) # Compare signals at end of recording diff --git a/pyprep/tests/test_prep_pipeline.py b/pyprep/tests/test_prep_pipeline.py index 25acfc8..99a9520 100644 --- a/pyprep/tests/test_prep_pipeline.py +++ b/pyprep/tests/test_prep_pipeline.py @@ -1,11 +1,9 @@ """Test the full PREP pipeline.""" import matplotlib.pyplot as plt -import mne import numpy as np import pytest -import scipy.io as sio -from pyprep.prep_pipeline import PrepPipeline +from pyprep import PrepPipeline from .conftest import make_random_mne_object @@ -13,199 +11,86 @@ @pytest.mark.usefixtures("raw", "montage") def test_prep_pipeline(raw, montage): """Test prep pipeline.""" - eeg_index = mne.pick_types(raw.info, eeg=True, eog=False, meg=False) + # Pick only EEG channels + raw.pick(picks="eeg") + + # Create a copy of raw data raw_copy = raw.copy() - ch_names = raw_copy.info["ch_names"] - ch_names_eeg = list(np.asarray(ch_names)[eeg_index]) + + # Get channel names (after picking EEG channels) + ch_names_eeg = raw_copy.info["ch_names"] + + # Setup preprocessing parameters sample_rate = raw_copy.info["sfreq"] prep_params = { "ref_chs": ch_names_eeg, "reref_chs": ch_names_eeg, "line_freqs": np.arange(60, sample_rate / 2, 60), } + + # Initialize and fit PrepPipeline prep = PrepPipeline(raw_copy, prep_params, montage, random_state=42) prep.fit() - EEG_raw = raw_copy.get_data(picks="eeg") * 1e6 - EEG_raw_max = np.max(abs(EEG_raw), axis=None) - EEG_raw_matlab = sio.loadmat("./examples/matlab_results/EEG_raw.mat") - EEG_raw_matlab = EEG_raw_matlab["save_data"] - EEG_raw_diff = EEG_raw - EEG_raw_matlab - # EEG_raw_mse = (EEG_raw_diff / EEG_raw_max ** 2).mean(axis=None) + # Load MATLAB results - fig, axs = plt.subplots(5, 3, sharex="all") + # Extract data from the pipeline + data = { + "EEG_raw": raw_copy.get_data(picks="eeg") * 1e6, + "EEG_new": prep.EEG_new, + "EEG_clean": prep.EEG, + "EEG_before_interpolation": prep.EEG_before_interpolation, + "EEG_final": prep.raw.get_data() * 1e6, + } + + # Calculate maximum values for normalization + data_max = {key: np.max(np.abs(value)) for key, value in data.items()} + + # Create plots + fig, axs = plt.subplots(5, 3, sharex="all", figsize=(15, 15)) plt.setp(fig, facecolor=[1, 1, 1]) fig.suptitle("Python versus Matlab PREP results", fontsize=16) - im = axs[0, 0].imshow( - EEG_raw / EEG_raw_max, - aspect="auto", - extent=[0, (EEG_raw.shape[1] / sample_rate), 63, 0], - vmin=-1, - vmax=1, - cmap=plt.get_cmap("RdBu"), - ) - axs[0, 0].set_title("Python", fontsize=14) - axs[0, 1].imshow( - EEG_raw_matlab / EEG_raw_max, - aspect="auto", - extent=[0, (EEG_raw_matlab.shape[1] / sample_rate), 63, 0], - vmin=-1, - vmax=1, - cmap=plt.get_cmap("RdBu"), - ) - axs[0, 1].set_title("Matlab", fontsize=14) - axs[0, 2].imshow( - EEG_raw_diff / EEG_raw_max, - aspect="auto", - extent=[0, (EEG_raw_diff.shape[1] / sample_rate), 63, 0], - vmin=-1, - vmax=1, - cmap=plt.get_cmap("RdBu"), - ) - axs[0, 2].set_title("Difference", fontsize=14) - # axs[0, 0].set_title('Original EEG', loc='left', fontsize=14) - # axs[0, 0].set_ylabel('Channel Number', fontsize=14) + def plot_data(ax, data_key, matlab_key, title): + im = ax.imshow( + data[data_key] / data_max[data_key], + aspect="auto", + extent=[0, (data[data_key].shape[1] / sample_rate), 63, 0], + vmin=-1, + vmax=1, + cmap=plt.get_cmap("RdBu"), + ) + ax.set_title(title, fontsize=14) + return im + + # Plot each stage of data + im = plot_data(axs[0, 0], "EEG_raw", "EEG_raw", "Python") + plot_data(axs[0, 1], "EEG_raw", "EEG_raw", "Matlab") + plot_data(axs[0, 2], "EEG_raw", "EEG_raw", "Difference") + + plot_data(axs[1, 0], "EEG_new", "EEGNew", "Python") + plot_data(axs[1, 1], "EEG_new", "EEGNew", "Matlab") + plot_data(axs[1, 2], "EEG_new", "EEGNew", "Difference") + + plot_data(axs[2, 0], "EEG_clean", "EEG", "Python") + plot_data(axs[2, 1], "EEG_clean", "EEG", "Matlab") + plot_data(axs[2, 2], "EEG_clean", "EEG", "Difference") + + plot_data(axs[3, 0], "EEG_before_interpolation", "EEGref", "Python") + plot_data(axs[3, 1], "EEG_before_interpolation", "EEGref", "Matlab") + plot_data(axs[3, 2], "EEG_before_interpolation", "EEGref", "Difference") + + plot_data(axs[4, 0], "EEG_final", "EEGinterp", "Python") + plot_data(axs[4, 1], "EEG_final", "EEGinterp", "Matlab") + plot_data(axs[4, 2], "EEG_final", "EEGinterp", "Difference") + + # Colorbar and labels cb = fig.colorbar(im, ax=axs, fraction=0.05, pad=0.04) cb.set_label("\u03BCVolt", fontsize=14) - - EEG_new_matlab = sio.loadmat("./examples/matlab_results/EEGNew.mat") - EEG_new_matlab = EEG_new_matlab["save_data"] - EEG_new = prep.EEG_new - EEG_new_max = np.max(abs(EEG_new), axis=None) - EEG_new_diff = EEG_new - EEG_new_matlab - # EEG_new_mse = ((EEG_new_diff / EEG_new_max) ** 2).mean(axis=None) - axs[1, 0].imshow( - EEG_new / EEG_new_max, - aspect="auto", - extent=[0, (EEG_new.shape[1] / sample_rate), 63, 0], - vmin=-1, - vmax=1, - cmap=plt.get_cmap("RdBu"), - ) - axs[1, 1].imshow( - EEG_new_matlab / EEG_new_max, - aspect="auto", - extent=[0, (EEG_new_matlab.shape[1] / sample_rate), 63, 0], - vmin=-1, - vmax=1, - cmap=plt.get_cmap("RdBu"), - ) - axs[1, 2].imshow( - EEG_new_diff / EEG_new_max, - aspect="auto", - extent=[0, (EEG_new_diff.shape[1] / sample_rate), 63, 0], - vmin=-1, - vmax=1, - cmap=plt.get_cmap("RdBu"), - ) - # axs[1, 0].set_title('High pass filter', loc='left', fontsize=14) - # axs[1, 0].set_ylabel('Channel Number', fontsize=14) - - EEG_clean_matlab = sio.loadmat("./examples/matlab_results/EEG.mat") - EEG_clean_matlab = EEG_clean_matlab["save_data"] - EEG_clean = prep.EEG - EEG_max = np.max(abs(EEG_clean), axis=None) - EEG_diff = EEG_clean - EEG_clean_matlab - # EEG_mse = ((EEG_diff / EEG_max) ** 2).mean(axis=None) - axs[2, 0].imshow( - EEG_clean / EEG_max, - aspect="auto", - extent=[0, (EEG_clean.shape[1] / sample_rate), 63, 0], - vmin=-1, - vmax=1, - cmap=plt.get_cmap("RdBu"), - ) - axs[2, 1].imshow( - EEG_clean_matlab / EEG_max, - aspect="auto", - extent=[0, (EEG_clean_matlab.shape[1] / sample_rate), 63, 0], - vmin=-1, - vmax=1, - cmap=plt.get_cmap("RdBu"), - ) - axs[2, 2].imshow( - EEG_diff / EEG_max, - aspect="auto", - extent=[0, (EEG_diff.shape[1] / sample_rate), 63, 0], - vmin=-1, - vmax=1, - cmap=plt.get_cmap("RdBu"), - ) - # axs[2, 0].set_title('Line-noise removal', loc='left', fontsize=14) + axs[4, 1].set_xlabel("Time (s)", fontsize=14) axs[2, 0].set_ylabel("Channel Number", fontsize=14) - EEG = prep.EEG_before_interpolation - EEG_max = np.max(abs(EEG), axis=None) - EEG_ref_mat = sio.loadmat("./examples/matlab_results/EEGref.mat") - EEG_ref_matlab = EEG_ref_mat["save_EEG"] - # reference_matlab = EEG_ref_mat["save_reference"] - EEG_ref_diff = EEG - EEG_ref_matlab - # EEG_ref_mse = ((EEG_ref_diff / EEG_max) ** 2).mean(axis=None) - # reference_signal = prep.reference_before_interpolation - # reference_max = np.max(abs(reference_signal), axis=None) - # reference_diff = reference_signal - reference_matlab - # reference_mse = ((reference_diff / reference_max) ** 2).mean(axis=None) - axs[3, 0].imshow( - EEG / EEG_max, - aspect="auto", - extent=[0, (EEG.shape[1] / sample_rate), 63, 0], - vmin=-1, - vmax=1, - cmap=plt.get_cmap("RdBu"), - ) - axs[3, 1].imshow( - EEG_ref_matlab / EEG_max, - aspect="auto", - extent=[0, (EEG_ref_matlab.shape[1] / sample_rate), 63, 0], - vmin=-1, - vmax=1, - cmap=plt.get_cmap("RdBu"), - ) - axs[3, 2].imshow( - EEG_ref_diff / EEG_max, - aspect="auto", - extent=[0, (EEG_ref_diff.shape[1] / sample_rate), 63, 0], - vmin=-1, - vmax=1, - cmap=plt.get_cmap("RdBu"), - ) - # axs[3, 0].set_title('Referencing', loc='left', fontsize=14) - # axs[3, 0].set_ylabel('Channel Number', fontsize=14) - - EEG_final = prep.raw.get_data() * 1e6 - EEG_final_max = np.max(abs(EEG_final), axis=None) - EEG_final_matlab = sio.loadmat("./examples/matlab_results/EEGinterp.mat") - EEG_final_matlab = EEG_final_matlab["save_data"] - EEG_final_diff = EEG_final - EEG_final_matlab - # EEG_final_mse = ((EEG_final_diff / EEG_final_max) ** 2).mean(axis=None) - axs[4, 0].imshow( - EEG_final / EEG_final_max, - aspect="auto", - extent=[0, (EEG_final.shape[1] / sample_rate), 63, 0], - vmin=-1, - vmax=1, - cmap=plt.get_cmap("RdBu"), - ) - axs[4, 1].imshow( - EEG_final_matlab / EEG_final_max, - aspect="auto", - extent=[0, (EEG_final_matlab.shape[1] / sample_rate), 63, 0], - vmin=-1, - vmax=1, - cmap=plt.get_cmap("RdBu"), - ) - axs[4, 2].imshow( - EEG_final_diff / EEG_final_max, - aspect="auto", - extent=[0, (EEG_final_diff.shape[1] / sample_rate), 63, 0], - vmin=-1, - vmax=1, - cmap=plt.get_cmap("RdBu"), - ) - # axs[4, 0].set_title('Interpolation', loc='left', fontsize=14) - # axs[4, 0].set_ylabel('Channel Number', fontsize=14) - axs[4, 1].set_xlabel("Time(s)", fontsize=14) + plt.show() @pytest.mark.usefixtures("raw", "montage") @@ -254,21 +139,38 @@ def test_prep_pipeline_non_eeg(raw, montage): @pytest.mark.usefixtures("raw", "montage") def test_prep_pipeline_filter_kwargs(raw, montage): """Test prep pipeline with filter kwargs.""" - eeg_index = mne.pick_types(raw.info, eeg=True, eog=False, meg=False) + # Get EEG channel indices + raw.pick(picks="eeg") # Picks only EEG channels + + # Create a copy of raw data to avoid modifying the original raw_copy = raw.copy() - ch_names = raw_copy.info["ch_names"] - ch_names_eeg = list(np.asarray(ch_names)[eeg_index]) + + # Extract EEG channel names + ch_names_eeg = raw_copy.ch_names # Since we picked only EEG channels, use all names + + # Get the sample rate from the copied raw data sample_rate = raw_copy.info["sfreq"] + + # Prepare parameters for the pipeline prep_params = { "ref_chs": ch_names_eeg, "reref_chs": ch_names_eeg, "line_freqs": np.arange(60, sample_rate / 2, 60), } + + # Define filter kwargs filter_kwargs = { "method": "fir", "phase": "zero-double", } + # Initialize and fit the PrepPipeline + prep = PrepPipeline( + raw_copy, prep_params, montage, random_state=42, filter_kwargs=filter_kwargs + ) + prep.fit() + + # Initialize and fit the PrepPipeline prep = PrepPipeline( raw_copy, prep_params, montage, random_state=42, filter_kwargs=filter_kwargs ) diff --git a/pyprep/tests/test_reference.py b/pyprep/tests/test_reference.py index 1794433..0cc9e2f 100644 --- a/pyprep/tests/test_reference.py +++ b/pyprep/tests/test_reference.py @@ -8,6 +8,50 @@ from pyprep.reference import Reference +@pytest.mark.usefixtures("raw_clean") +def test_reference_no_bad_channels(raw_clean): + """Test robust reference with no bad channels.""" + ch_names = raw_clean.info["ch_names"] + params = {"ref_chs": ch_names, "reref_chs": ch_names} + + # Mock NoisyChannels to return no bad channels + with mock.patch("pyprep.NoisyChannels.get_bads", return_value={"bad_all": []}): + reference = Reference(raw_clean, params, ransac=False) + reference.perform_reference() + + assert len(reference.noisy_channels["bad_all"]) == 0 + assert reference.reference_signal is not None + + +@pytest.mark.usefixtures("raw_clean") +def test_reference_max_iterations(raw_clean): + """Test robust reference to ensure it respects max_iterations.""" + ch_names = raw_clean.info["ch_names"] + params = {"ref_chs": ch_names, "reref_chs": ch_names} + + with mock.patch("pyprep.NoisyChannels.find_all_bads", return_value=True): + reference = Reference(raw_clean, params, ransac=False) + # Force the loop to iterate the maximum number of times + reference.robust_reference(max_iterations=1) + + # Check that the reference_signal was updated and no errors occurred + assert reference.reference_signal is not None + + +@pytest.mark.usefixtures("raw_clean") +def test_reference_matlab_strict(raw_clean): + """Test robust reference with matlab_strict set to True and False.""" + ch_names = raw_clean.info["ch_names"] + params = {"ref_chs": ch_names, "reref_chs": ch_names} + + for strict in [True, False]: + reference = Reference(raw_clean, params, ransac=False, matlab_strict=strict) + reference.perform_reference() + + assert reference.reference_signal is not None + assert isinstance(reference.noisy_channels, dict) + + @pytest.mark.usefixtures("raw", "montage") def test_basic_input(raw, montage): """Test Reference output data type.""" diff --git a/pyprep/tests/test_utils.py b/pyprep/tests/test_utils.py index 03ba2ee..afab974 100644 --- a/pyprep/tests/test_utils.py +++ b/pyprep/tests/test_utils.py @@ -1,10 +1,14 @@ """Test various helper functions.""" +import logging + +import mne import numpy as np import pytest from pyprep.utils import ( _correlate_arrays, _eeglab_create_highpass, + _eeglab_interpolate_bads, _get_random_subset, _mad, _mat_iqr, @@ -13,6 +17,29 @@ ) +def test_eeglab_interpolate_bads_no_bad_channels(caplog): + """Test _eeglab_interpolate_bads when no bad channels are present.""" + # Create a Raw object with no bad channels + data = np.random.randn(64, 1000) # 64 channels, 1000 samples + info = mne.create_info( + ch_names=[f"EEG {i}" for i in range(64)], sfreq=1000, ch_types="eeg" + ) + raw = mne.io.RawArray(data, info) + + # Ensure there are no bad channels + raw.info["bads"] = [] + + # Use caplog to capture the logging output + with caplog.at_level(logging.INFO): + _eeglab_interpolate_bads(raw) + + # Assert that the appropriate log message was emitted + assert "No bad channels to interpolate." in caplog.text + + # Verify that the function exits early without modifying the data + assert raw.info["bads"] == [] + + def test_mat_round(): """Test the MATLAB-compatible rounding function.""" # Test normal rounding behaviour diff --git a/pyprep/utils.py b/pyprep/utils.py index ae29678..60b0760 100644 --- a/pyprep/utils.py +++ b/pyprep/utils.py @@ -1,8 +1,8 @@ """Module contains frequently used functions dealing with channel lists.""" +import logging import math from cmath import sqrt -import mne import numpy as np import scipy.interpolate from mne.surface import _normalize_vectors @@ -11,6 +11,8 @@ from scipy import linalg from scipy.signal import firwin, lfilter, lfilter_zi +logger = logging.getLogger(__name__) + def _union(list1, list2): return list(set(list1 + list2)) @@ -352,28 +354,33 @@ def _eeglab_interpolate_bads(raw): appears to be loosely based on the same general Perrin et al. (1989) method as MNE's interpolation, but there are several quirks with the implementation that cause it to produce fairly different numbers. - """ - # Get the indices of good and bad EEG channels - eeg_chans = mne.pick_types(raw.info, eeg=True, exclude=[]) - good_idx = mne.pick_types(raw.info, eeg=True, exclude="bads") - bad_idx = sorted(_set_diff(eeg_chans, good_idx)) + # Get the indices of EEG channels + eeg_chans = raw.copy().pick(picks="eeg", exclude=[]).ch_names + good_chans = raw.copy().pick(picks="eeg", exclude="bads").ch_names + + # Determine bad channels by comparing all EEG channels with the good ones + bad_chans = sorted(_set_diff(eeg_chans, good_chans)) + + if not bad_chans: + logger.info("No bad channels to interpolate.") + return # Get the spatial coordinates of the good and bad electrodes elec_pos = raw._get_channel_positions(picks=eeg_chans) - pos_good = elec_pos[good_idx, :].copy() - pos_bad = elec_pos[bad_idx, :].copy() + pos_good = elec_pos[[eeg_chans.index(ch) for ch in good_chans], :] + pos_bad = elec_pos[[eeg_chans.index(ch) for ch in bad_chans], :] + + # Normalize the electrode positions _normalize_vectors(pos_good) _normalize_vectors(pos_bad) - # Interpolate bad channels - interp = _eeglab_interpolate(raw.get_data()[good_idx, :], pos_good, pos_bad) - raw._data[bad_idx, :] = interp + # Interpolate the bad channels + interp_data = _eeglab_interpolate(raw.get_data(picks=good_chans), pos_good, pos_bad) + raw._data[[raw.ch_names.index(ch) for ch in bad_chans], :] = interp_data - # Clear all bad EEG channels - eeg_bad_names = [raw.info["ch_names"][i] for i in bad_idx] - bads_non_eeg = _set_diff(raw.info["bads"], eeg_bad_names) - raw.info["bads"] = bads_non_eeg + # Remove bad EEG channels from the list of bads in raw.info + raw.info["bads"] = _set_diff(raw.info["bads"], bad_chans) def _get_random_subset(x, size, rand_state):