Skip to content

Commit

Permalink
Updated test tolerances due to MacOS. Also updated README.
Browse files Browse the repository at this point in the history
  • Loading branch information
Sm00thix committed Aug 4, 2024
1 parent a284782 commit bedf910
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 24 deletions.
9 changes: 5 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
The `ikpls` software package provides fast and efficient tools for PLS (Partial Least Squares) modeling. This package is designed to help researchers and practitioners handle PLS modeling faster than previously possible - particularly on large datasets.

## Citation
If you use the `ikpls` software package for your work, please cite [this Journal of Open Source Software paper](https://joss.theoj.org/papers/10.21105/joss.06533). If you use the fast cross-validation algorithm implemented in `ikpls.fast_cross_validation.numpy_ikpls`, please also cite [this arXiv preprint](https://arxiv.org/abs/2401.13185).
If you use the `ikpls` software package for your work, please cite [this Journal of Open Source Software article](https://joss.theoj.org/papers/10.21105/joss.06533). If you use the fast cross-validation algorithm implemented in `ikpls.fast_cross_validation.numpy_ikpls`, please also cite [this arXiv preprint](https://arxiv.org/abs/2401.13185).

## Unlock the Power of Fast and Stable Partial Least Squares Modeling with IKPLS

Expand Down Expand Up @@ -52,9 +52,10 @@ and scaling can be enabled or disabled independently from eachother and for X an
by setting the parameters `center_X`, `center_Y`, `scale_X`, and `scale_Y`, respectively.
In addition to correctly handling (column-wise) centering and scaling,
the fast cross-validation algorithm **correctly handles row-wise preprocessing**
such as (row-wise) centering and scaling of the X and Y input matrices,
convolution, or other preprocessing. Row-wise preprocessing can safely be
applied before passing the data to the fast cross-validation algorithm.
that operates independently on each sample such as (row-wise) centering and scaling
of the X and Y input matrices, convolution, or other preprocessing. Row-wise
preprocessing can safely be applied before passing the data to the fast
cross-validation algorithm.

## Prerequisites

Expand Down
40 changes: 20 additions & 20 deletions tests/test_ikpls.py
Original file line number Diff line number Diff line change
Expand Up @@ -895,8 +895,8 @@ def test_pls_1(self) -> None:
jax_pls_alg_2=jax_pls_alg_2,
diff_jax_pls_alg_1=diff_jax_pls_alg_1,
diff_jax_pls_alg_2=diff_jax_pls_alg_2,
atol=1e-8,
rtol=6e-5,
atol=3e-8,
rtol=2e-4,
)

self.check_predictions(
Expand Down Expand Up @@ -3178,26 +3178,26 @@ def test_fast_cross_val_pls_1(self):
splits = self.load_Y(["split"])
assert Y.shape[1] == 1
self.check_fast_cross_val_pls(
X, Y, splits, center=False, scale=False, atol=0, rtol=1e-8
X, Y, splits, center=False, scale=False, atol=0, rtol=2e-8
)
self.check_fast_cross_val_pls(
X, Y, splits, center=True, scale=False, atol=0, rtol=1e-8
X, Y, splits, center=True, scale=False, atol=0, rtol=2e-8
)
self.check_fast_cross_val_pls(
X, Y, splits, center=True, scale=True, atol=0, rtol=1e-8
X, Y, splits, center=True, scale=True, atol=0, rtol=2e-8
)

# Remove the singleton dimension and check that the predictions are consistent.
Y = Y.squeeze()
assert Y.ndim == 1
self.check_fast_cross_val_pls(
X, Y, splits, center=False, scale=False, atol=0, rtol=1e-8
X, Y, splits, center=False, scale=False, atol=0, rtol=2e-8
)
self.check_fast_cross_val_pls(
X, Y, splits, center=True, scale=False, atol=0, rtol=1e-8
X, Y, splits, center=True, scale=False, atol=0, rtol=2e-8
)
self.check_fast_cross_val_pls(
X, Y, splits, center=True, scale=True, atol=0, rtol=1e-8
X, Y, splits, center=True, scale=True, atol=0, rtol=2e-8
)

# JAX will issue a warning if os.fork() is called as JAX is incompatible with
Expand Down Expand Up @@ -3243,13 +3243,13 @@ def test_fast_cross_val_pls_2_m_less_k(self):
assert Y.shape[1] > 1
assert Y.shape[1] < X.shape[1]
self.check_fast_cross_val_pls(
X, Y, splits, center=False, scale=False, atol=0, rtol=1e-7
X, Y, splits, center=False, scale=False, atol=0, rtol=2e-7
)
self.check_fast_cross_val_pls(
X, Y, splits, center=True, scale=False, atol=0, rtol=1e-7
X, Y, splits, center=True, scale=False, atol=0, rtol=2e-7
)
self.check_fast_cross_val_pls(
X, Y, splits, center=True, scale=True, atol=0, rtol=1e-7
X, Y, splits, center=True, scale=True, atol=0, rtol=2e-7
)

# JAX will issue a warning if os.fork() is called as JAX is incompatible with
Expand Down Expand Up @@ -3296,13 +3296,13 @@ def test_fast_cross_val_pls_2_m_eq_k(self):
assert Y.shape[1] > 1
assert Y.shape[1] == X.shape[1]
self.check_fast_cross_val_pls(
X, Y, splits, center=False, scale=False, atol=0, rtol=1e-8
X, Y, splits, center=False, scale=False, atol=0, rtol=2e-8
)
self.check_fast_cross_val_pls(
X, Y, splits, center=True, scale=False, atol=0, rtol=1e-8
X, Y, splits, center=True, scale=False, atol=0, rtol=2e-8
)
self.check_fast_cross_val_pls(
X, Y, splits, center=True, scale=True, atol=0, rtol=1e-8
X, Y, splits, center=True, scale=True, atol=0, rtol=2e-8
)

# JAX will issue a warning if os.fork() is called as JAX is incompatible with
Expand Down Expand Up @@ -3448,13 +3448,13 @@ def test_fast_cross_val_pls_2_m_less_k_loocv(self):
assert Y.shape[1] > 1
assert Y.shape[1] < X.shape[1]
self.check_fast_cross_val_pls(
X, Y, splits, center=False, scale=False, atol=2e-6, rtol=1e-8
X, Y, splits, center=False, scale=False, atol=6e-6, rtol=2e-8
)
self.check_fast_cross_val_pls(
X, Y, splits, center=True, scale=False, atol=5e-6, rtol=1e-8
X, Y, splits, center=True, scale=False, atol=6e-6, rtol=2e-8
)
self.check_fast_cross_val_pls(
X, Y, splits, center=True, scale=True, atol=3e-6, rtol=1e-8
X, Y, splits, center=True, scale=True, atol=6e-6, rtol=2e-8
)

def test_fast_cross_val_pls_2_m_eq_k_loocv(self):
Expand Down Expand Up @@ -3494,10 +3494,10 @@ def test_fast_cross_val_pls_2_m_eq_k_loocv(self):
assert Y.shape[1] > 1
assert Y.shape[1] == X.shape[1]
self.check_fast_cross_val_pls(
X, Y, splits, center=False, scale=False, atol=1e-7, rtol=1e-8
X, Y, splits, center=False, scale=False, atol=2e-7, rtol=1e-8
)
self.check_fast_cross_val_pls(
X, Y, splits, center=True, scale=False, atol=1e-7, rtol=1e-8
X, Y, splits, center=True, scale=False, atol=2e-7, rtol=1e-8
)
self.check_fast_cross_val_pls(
X, Y, splits, center=True, scale=True, atol=1e-7, rtol=1e-8
Expand Down Expand Up @@ -4111,7 +4111,7 @@ def test_center_scale_combinations_pls_2_m_eq_k(self):
splits = self.load_Y(["split"]) # Contains 3 splits of different sizes
assert Y.shape[1] > 1
assert Y.shape[1] == X.shape[1]
self.check_center_scale_combinations(X, Y, splits, atol=0, rtol=1e-8)
self.check_center_scale_combinations(X, Y, splits, atol=0, rtol=3e-8)

# JAX will issue a warning if os.fork() is called as JAX is incompatible with
# multi-threaded code. os.fork() is called by the other cross-validation
Expand Down

0 comments on commit bedf910

Please sign in to comment.