Skip to content

Commit

Permalink
Fixes for Numpy 2 (#148)
Browse files Browse the repository at this point in the history
* np.NaN -> np.nan

* remove unneeded 'nyq' argument in firwin

* switch to default_rng

* some more changes

* fix random state copy

* deal with warning

* move to default_rng

* fix imports, no capitalization where not in style
  • Loading branch information
sappelhoff authored Jul 28, 2024
1 parent e4840d9 commit c6a92d1
Show file tree
Hide file tree
Showing 14 changed files with 44 additions and 45 deletions.
1 change: 1 addition & 0 deletions .gitattributes
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
pyprep/_version.py export-subst
* text=auto
5 changes: 1 addition & 4 deletions .github/workflows/python_build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,6 @@ on:
branches: [ main ]
pull_request:
branches: [ main ]
create:
branches: [ main ]
tags: [ '**' ]
schedule:
- cron: "0 4 * * MON"

Expand All @@ -18,7 +15,7 @@ jobs:
fail-fast: false
matrix:
os: [ubuntu-latest]
python-version: ["3.11"]
python-version: ["3.12"]
env:
TZ: Europe/Berlin
FORCE_COLOR: true
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/python_publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.11"
python-version: "3.12"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
Expand Down
7 changes: 2 additions & 5 deletions .github/workflows/python_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,6 @@ on:
branches: [ main ]
pull_request:
branches: [ main ]
create:
branches: [ main ]
tags: [ '**' ]
schedule:
- cron: "0 4 * * MON"

Expand All @@ -17,13 +14,13 @@ jobs:
fail-fast: false
matrix:
platform: [ubuntu-latest, macos-latest, windows-latest]
python-version: ["3.11"]
python-version: ["3.12"]
mne-version: [mne-stable]

include:
# Test mne development version only on ubuntu
- platform: ubuntu-latest
python-version: "3.11"
python-version: "3.12"
mne-version: mne-main
run-as-extra: true

Expand Down
2 changes: 1 addition & 1 deletion .readthedocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ version: 2
build:
os: ubuntu-22.04
tools:
python: "3.11"
python: "3.12"

# Build documentation in the docs/ directory with Sphinx
sphinx:
Expand Down
6 changes: 3 additions & 3 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@


.. image:: https://readthedocs.org/projects/pyprep/badge/?version=latest
:target: http://pyprep.readthedocs.io/en/latest/?badge=latest
:target: https://pyprep.readthedocs.io/en/latest/?badge=latest
:alt: Documentation Status


Expand All @@ -37,8 +37,8 @@ pyprep

For documentation, see the:

- `stable documentation <http://pyprep.readthedocs.io/en/stable/>`_
- `latest (development) documentation <http://pyprep.readthedocs.io/en/latest/>`_
- `stable documentation <https://pyprep.readthedocs.io/en/stable/>`_
- `latest (development) documentation <https://pyprep.readthedocs.io/en/latest/>`_

.. docs_readme_include_label
Expand Down
2 changes: 1 addition & 1 deletion docs/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ Changelog
- Initial commit: 2018-04-12
- Miscellaneous changes

.. _Stefan Appelhoff: http://stefanappelhoff.com/
.. _Stefan Appelhoff: https://stefanappelhoff.com/
.. _Aamna Lawrence: https://github.com/AamnaLawrence
.. _Adam Li: https://github.com/adam2392/
.. _Christian O'Reilly: https://github.com/christian-oreilly
Expand Down
11 changes: 8 additions & 3 deletions pyprep/find_noisy_channels.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
"""finds bad channels."""
from copy import copy

import mne
import numpy as np
from mne.utils import check_random_state, logger
Expand Down Expand Up @@ -579,7 +577,14 @@ def find_bad_by_ransac(
exclude_from_ransac = (
self.bad_by_correlation + self.bad_by_deviation + self.bad_by_dropout
)
rng = copy(self.random_state) if self.matlab_strict else self.random_state

if self.matlab_strict:
random_state = self.random_state.get_state()
rng = np.random.RandomState()
rng.set_state(random_state)
else:
rng = self.random_state

self.bad_by_ransac, ch_correlations_usable = find_bad_by_ransac(
self.EEGFiltered,
self.sample_rate,
Expand Down
2 changes: 1 addition & 1 deletion pyprep/removeTrend.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,5 +170,5 @@ def runline(y, n, dn):
(np.multiply(np.arange(n + 1, n + npts + 1), a) + b), (npts, 1)
)
for i in range(0, len(y_line)):
y[i] = y[i] - y_line[i]
y[i] = y[i] - y_line[i, 0]
return y
8 changes: 4 additions & 4 deletions pyprep/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def _mat_quantile(arr, q, axis=None):
# Sort the array in ascending order along the given axis (any NaNs go to the end)
# Return NaN if array is empty.
if len(arr) == 0:
return np.NaN
return np.nan
arr_sorted = np.sort(arr, axis=axis)

# Ensure array is either 1D or 2D
Expand Down Expand Up @@ -182,7 +182,7 @@ def _eeglab_create_highpass(cutoff, srate):
N = order + 1
filt = np.zeros(N)
filt[N // 2] = 1
filt -= firwin(N, transition, window="hamming", nyq=1)
filt -= firwin(N, transition, window="hamming")
return filt


Expand Down Expand Up @@ -215,7 +215,7 @@ def _eeglab_fir_filter(data, filt):
pad_len = min(group_delay, n_samples)

# Prepare initial state of filter, using padding at start of data
start_pad_idx = np.zeros(pad_len, dtype=np.uint8)
start_pad_idx = np.zeros(pad_len, dtype=np.uint)
start_padded = np.concatenate((data[:, start_pad_idx], data[:, :pad_len]), axis=1)
zi_init = lfilter_zi(filt, 1) * np.take(start_padded, [0], axis=0)
_, zi = lfilter(filt, 1, start_padded, axis=1, zi=zi_init)
Expand All @@ -232,7 +232,7 @@ def _eeglab_fir_filter(data, filt):
)

# Finish filtering data, using padding at end to calculate final values
end_pad_idx = np.zeros(pad_len, dtype=np.uint8) + (n_samples - 1)
end_pad_idx = np.zeros(pad_len, dtype=np.uint) + (n_samples - 1)
end, _ = lfilter(filt, 1, data[:, end_pad_idx], axis=1, zi=zi)
out[:, (n_samples - pad_len) :] = end[:, (group_delay - pad_len) :]

Expand Down
11 changes: 6 additions & 5 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def make_random_mne_object(
n_freq_comps=5,
freq_range=[10, 60],
scale=1e-6,
RNG=np.random.RandomState(1337),
rng=np.random.default_rng(1337),
):
"""Make a random MNE object to use for testing.
Expand All @@ -98,8 +98,9 @@ def make_random_mne_object(
scale : float
Scaling factor applied to the signal in volts. For example 1e-6 to
get microvolts.
RNG : np.random.RandomState
Random state seed.
rng : np.random.Generator
The random number generator object. Must be created with
``np.random.default_rng``.
Returns
-------
Expand All @@ -120,8 +121,8 @@ def make_random_mne_object(
high = freq_range[1]
for chan in range(n_chans):
# Each channel signal is a sum of random freq sine waves
for freq_i in range(n_freq_comps):
freq = RNG.randint(low, high, signal_len)
for _ in range(n_freq_comps):
freq = rng.integers(low, high, signal_len)
signal[chan, :] += np.sin(2 * np.pi * times * freq)

signal *= scale # scale
Expand Down
25 changes: 12 additions & 13 deletions tests/test_find_noisy_channels.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
"""Test the find_noisy_channels module."""
import numpy as np
import pytest
from numpy.random import RandomState

from pyprep.find_noisy_channels import NoisyChannels
from pyprep.ransac import find_bad_by_ransac
from pyprep.removeTrend import removeTrend

# Set a fixed random seed for reproducible test results

RNG = RandomState(30)
rng = np.random.default_rng(30)


# Define some fixtures and utility functions for use across multiple tests
Expand Down Expand Up @@ -47,7 +46,7 @@ def raw_tmp(raw_clean_detrend):
def _generate_signal(fmin, fmax, timepoints, fcount=1):
"""Generate an EEG signal from one or more sine waves in a frequency range."""
signal = np.zeros_like(timepoints)
for freq in RNG.randint(fmin, fmax + 1, fcount):
for freq in rng.integers(fmin, fmax + 1, fcount):
signal += np.sin(2 * np.pi * timepoints * freq)
return signal * 1e-6

Expand All @@ -59,7 +58,7 @@ def test_bad_by_nan(raw_tmp):
"""Test the detection of channels containing any NaN values."""
# Insert a NaN value into a random channel
n_chans = raw_tmp.get_data().shape[0]
nan_idx = int(RNG.randint(0, n_chans, 1))
nan_idx = int(rng.integers(0, n_chans, 1)[0])
raw_tmp._data[nan_idx, 3] = np.nan

# Test automatic detection of NaN channels on NoisyChannels init
Expand All @@ -75,7 +74,7 @@ def test_bad_by_flat(raw_tmp):
"""Test the detection of channels with flat or very weak signals."""
# Make the signal for a random channel extremely weak
n_chans = raw_tmp.get_data().shape[0]
flat_idx = int(RNG.randint(0, n_chans, 1))
flat_idx = int(rng.integers(0, n_chans, 1)[0])
raw_tmp._data[flat_idx, :] = raw_tmp.get_data()[flat_idx, :] * 1e-12

# Test automatic detection of flat channels on NoisyChannels init
Expand All @@ -100,7 +99,7 @@ def test_bad_by_deviation(raw_tmp):

# Make the signal for a random channel have a very high amplitude
n_chans = raw_tmp.get_data().shape[0]
high_dev_idx = int(RNG.randint(0, n_chans, 1))
high_dev_idx = int(rng.integers(0, n_chans, 1)[0])
raw_tmp._data[high_dev_idx, :] *= high_dev_factor

# Test detection of abnormally high-amplitude channels
Expand All @@ -126,7 +125,7 @@ def test_bad_by_hf_noise(raw_tmp):
"""Test detection of channels with high-frequency noise."""
# Add some noise between 70 & 80 Hz to the signal of a random channel
n_chans = raw_tmp.get_data().shape[0]
hf_noise_idx = int(RNG.randint(0, n_chans, 1))
hf_noise_idx = int(rng.integers(0, n_chans, 1)[0])
hf_noise = _generate_signal(70, 80, raw_tmp.times, 5) * 10
raw_tmp._data[hf_noise_idx, :] += hf_noise

Expand All @@ -148,7 +147,7 @@ def test_bad_by_dropout(raw_tmp):
"""Test detection of channels with excessive portions of flat signal."""
# Add large dropout portions to the signal of a random channel
n_chans, n_samples = raw_tmp.get_data().shape
dropout_idx = int(RNG.randint(0, n_chans, 1))
dropout_idx = int(rng.integers(0, n_chans, 1)[0])
x1, x2 = (int(n_samples / 10), int(2 * n_samples / 10))
raw_tmp._data[dropout_idx, x1:x2] = 0 # flatten 10% of signal

Expand All @@ -162,7 +161,7 @@ def test_bad_by_correlation(raw_tmp):
"""Test detection of channels that correlate poorly with others."""
# Replace a random channel's signal with uncorrelated values
n_chans, n_samples = raw_tmp.get_data().shape
low_corr_idx = int(RNG.randint(0, n_chans, 1))
low_corr_idx = int(rng.integers(0, n_chans, 1)[0])
raw_tmp._data[low_corr_idx, :] = _generate_signal(10, 30, raw_tmp.times, 5)

# Test detection of channels that correlate poorly with others
Expand All @@ -187,7 +186,7 @@ def test_bad_by_SNR(raw_tmp):
"""Test detection of channels that have low signal-to-noise ratios."""
# Replace a random channel's signal with uncorrelated values
n_chans = raw_tmp.get_data().shape[0]
low_snr_idx = int(RNG.randint(0, n_chans, 1))
low_snr_idx = int(rng.integers(0, n_chans, 1)[0])
raw_tmp._data[low_snr_idx, :] = _generate_signal(10, 30, raw_tmp.times, 5)

# Add some high-frequency noise to the uncorrelated channel
Expand All @@ -203,7 +202,7 @@ def test_bad_by_SNR(raw_tmp):
def test_find_bad_by_ransac(raw_tmp):
"""Test the RANSAC component of NoisyChannels."""
# Set a consistent random seed for all RANSAC runs
RANSAC_RNG = 435656
ransac_rng = 435656

# RANSAC identifies channels that go bad together and are highly correlated.
# Inserting highly correlated signal in channels 0 through 6 at 30 Hz
Expand All @@ -222,7 +221,7 @@ def test_find_bad_by_ransac(raw_tmp):
corr = {}
for name, args in test_matrix.items():
nd = NoisyChannels(
raw_tmp, do_detrend=False, random_state=RANSAC_RNG, matlab_strict=args[0]
raw_tmp, do_detrend=False, random_state=ransac_rng, matlab_strict=args[0]
)
nd.find_bad_by_ransac(channel_wise=args[1], max_chunk_size=args[2])
# Save bad channels and RANSAC correlation matrix for later comparison
Expand All @@ -247,7 +246,7 @@ def test_find_bad_by_ransac(raw_tmp):
assert not np.allclose(corr["by_window"], corr["by_window_strict"])

# Ensure that RANSAC doesn't change random state if in MATLAB-strict mode
rng = RandomState(RANSAC_RNG)
rng = np.random.RandomState(ransac_rng)
init_state = rng.get_state()[2]
nd = NoisyChannels(raw_tmp, do_detrend=False, random_state=rng, matlab_strict=True)
nd.find_bad_by_ransac()
Expand Down
2 changes: 1 addition & 1 deletion tests/test_prep_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ def test_prep_pipeline_non_eeg(raw, montage):
ch_types_non_eeg,
times,
sfreq,
RNG=np.random.RandomState(1337),
rng=np.random.default_rng(1337),
)

raw_copy.add_channels([raw_non_eeg], force_update_info=True)
Expand Down
5 changes: 2 additions & 3 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Test various helper functions."""
import numpy as np
import pytest
from numpy.random import RandomState

from pyprep.utils import (
_correlate_arrays,
Expand Down Expand Up @@ -64,7 +63,7 @@ def test_mat_quantile_iqr():

# Add NaNs to test data
tst_nan = tst.copy()
tst_nan[0, :] = np.NaN
tst_nan[0, :] = np.nan

# Create arrays containing MATLAB results for NaN test case
quantile_expected = np.asarray([0.9712, 0.9880, 0.9807])
Expand All @@ -91,7 +90,7 @@ def test_mat_quantile_iqr():
def test_get_random_subset():
"""Test the function for getting random channel subsets."""
# Generate test data
rng = RandomState(435656)
rng = np.random.RandomState(435656)
chans = range(1, 61)

# Compare random subset equivalence with MATLAB
Expand Down

0 comments on commit c6a92d1

Please sign in to comment.