From bedf9106b17864095466876060be0a675d8fb1a7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ole=20Engstr=C3=B8m?= Date: Mon, 5 Aug 2024 00:59:15 +0200 Subject: [PATCH] Updated test tolerances due to MacOS. Also updated README. --- README.md | 9 +++++---- tests/test_ikpls.py | 40 ++++++++++++++++++++-------------------- 2 files changed, 25 insertions(+), 24 deletions(-) diff --git a/README.md b/README.md index 4fbd915..b7e1e5b 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,7 @@ The `ikpls` software package provides fast and efficient tools for PLS (Partial Least Squares) modeling. This package is designed to help researchers and practitioners handle PLS modeling faster than previously possible - particularly on large datasets. ## Citation -If you use the `ikpls` software package for your work, please cite [this Journal of Open Source Software paper](https://joss.theoj.org/papers/10.21105/joss.06533). If you use the fast cross-validation algorithm implemented in `ikpls.fast_cross_validation.numpy_ikpls`, please also cite [this arXiv preprint](https://arxiv.org/abs/2401.13185). +If you use the `ikpls` software package for your work, please cite [this Journal of Open Source Software article](https://joss.theoj.org/papers/10.21105/joss.06533). If you use the fast cross-validation algorithm implemented in `ikpls.fast_cross_validation.numpy_ikpls`, please also cite [this arXiv preprint](https://arxiv.org/abs/2401.13185). ## Unlock the Power of Fast and Stable Partial Least Squares Modeling with IKPLS @@ -52,9 +52,10 @@ and scaling can be enabled or disabled independently from eachother and for X an by setting the parameters `center_X`, `center_Y`, `scale_X`, and `scale_Y`, respectively. In addition to correctly handling (column-wise) centering and scaling, the fast cross-validation algorithm **correctly handles row-wise preprocessing** -such as (row-wise) centering and scaling of the X and Y input matrices, -convolution, or other preprocessing. Row-wise preprocessing can safely be -applied before passing the data to the fast cross-validation algorithm. +that operates independently on each sample such as (row-wise) centering and scaling +of the X and Y input matrices, convolution, or other preprocessing. Row-wise +preprocessing can safely be applied before passing the data to the fast +cross-validation algorithm. ## Prerequisites diff --git a/tests/test_ikpls.py b/tests/test_ikpls.py index 6a86dd1..61b92c9 100644 --- a/tests/test_ikpls.py +++ b/tests/test_ikpls.py @@ -895,8 +895,8 @@ def test_pls_1(self) -> None: jax_pls_alg_2=jax_pls_alg_2, diff_jax_pls_alg_1=diff_jax_pls_alg_1, diff_jax_pls_alg_2=diff_jax_pls_alg_2, - atol=1e-8, - rtol=6e-5, + atol=3e-8, + rtol=2e-4, ) self.check_predictions( @@ -3178,26 +3178,26 @@ def test_fast_cross_val_pls_1(self): splits = self.load_Y(["split"]) assert Y.shape[1] == 1 self.check_fast_cross_val_pls( - X, Y, splits, center=False, scale=False, atol=0, rtol=1e-8 + X, Y, splits, center=False, scale=False, atol=0, rtol=2e-8 ) self.check_fast_cross_val_pls( - X, Y, splits, center=True, scale=False, atol=0, rtol=1e-8 + X, Y, splits, center=True, scale=False, atol=0, rtol=2e-8 ) self.check_fast_cross_val_pls( - X, Y, splits, center=True, scale=True, atol=0, rtol=1e-8 + X, Y, splits, center=True, scale=True, atol=0, rtol=2e-8 ) # Remove the singleton dimension and check that the predictions are consistent. Y = Y.squeeze() assert Y.ndim == 1 self.check_fast_cross_val_pls( - X, Y, splits, center=False, scale=False, atol=0, rtol=1e-8 + X, Y, splits, center=False, scale=False, atol=0, rtol=2e-8 ) self.check_fast_cross_val_pls( - X, Y, splits, center=True, scale=False, atol=0, rtol=1e-8 + X, Y, splits, center=True, scale=False, atol=0, rtol=2e-8 ) self.check_fast_cross_val_pls( - X, Y, splits, center=True, scale=True, atol=0, rtol=1e-8 + X, Y, splits, center=True, scale=True, atol=0, rtol=2e-8 ) # JAX will issue a warning if os.fork() is called as JAX is incompatible with @@ -3243,13 +3243,13 @@ def test_fast_cross_val_pls_2_m_less_k(self): assert Y.shape[1] > 1 assert Y.shape[1] < X.shape[1] self.check_fast_cross_val_pls( - X, Y, splits, center=False, scale=False, atol=0, rtol=1e-7 + X, Y, splits, center=False, scale=False, atol=0, rtol=2e-7 ) self.check_fast_cross_val_pls( - X, Y, splits, center=True, scale=False, atol=0, rtol=1e-7 + X, Y, splits, center=True, scale=False, atol=0, rtol=2e-7 ) self.check_fast_cross_val_pls( - X, Y, splits, center=True, scale=True, atol=0, rtol=1e-7 + X, Y, splits, center=True, scale=True, atol=0, rtol=2e-7 ) # JAX will issue a warning if os.fork() is called as JAX is incompatible with @@ -3296,13 +3296,13 @@ def test_fast_cross_val_pls_2_m_eq_k(self): assert Y.shape[1] > 1 assert Y.shape[1] == X.shape[1] self.check_fast_cross_val_pls( - X, Y, splits, center=False, scale=False, atol=0, rtol=1e-8 + X, Y, splits, center=False, scale=False, atol=0, rtol=2e-8 ) self.check_fast_cross_val_pls( - X, Y, splits, center=True, scale=False, atol=0, rtol=1e-8 + X, Y, splits, center=True, scale=False, atol=0, rtol=2e-8 ) self.check_fast_cross_val_pls( - X, Y, splits, center=True, scale=True, atol=0, rtol=1e-8 + X, Y, splits, center=True, scale=True, atol=0, rtol=2e-8 ) # JAX will issue a warning if os.fork() is called as JAX is incompatible with @@ -3448,13 +3448,13 @@ def test_fast_cross_val_pls_2_m_less_k_loocv(self): assert Y.shape[1] > 1 assert Y.shape[1] < X.shape[1] self.check_fast_cross_val_pls( - X, Y, splits, center=False, scale=False, atol=2e-6, rtol=1e-8 + X, Y, splits, center=False, scale=False, atol=6e-6, rtol=2e-8 ) self.check_fast_cross_val_pls( - X, Y, splits, center=True, scale=False, atol=5e-6, rtol=1e-8 + X, Y, splits, center=True, scale=False, atol=6e-6, rtol=2e-8 ) self.check_fast_cross_val_pls( - X, Y, splits, center=True, scale=True, atol=3e-6, rtol=1e-8 + X, Y, splits, center=True, scale=True, atol=6e-6, rtol=2e-8 ) def test_fast_cross_val_pls_2_m_eq_k_loocv(self): @@ -3494,10 +3494,10 @@ def test_fast_cross_val_pls_2_m_eq_k_loocv(self): assert Y.shape[1] > 1 assert Y.shape[1] == X.shape[1] self.check_fast_cross_val_pls( - X, Y, splits, center=False, scale=False, atol=1e-7, rtol=1e-8 + X, Y, splits, center=False, scale=False, atol=2e-7, rtol=1e-8 ) self.check_fast_cross_val_pls( - X, Y, splits, center=True, scale=False, atol=1e-7, rtol=1e-8 + X, Y, splits, center=True, scale=False, atol=2e-7, rtol=1e-8 ) self.check_fast_cross_val_pls( X, Y, splits, center=True, scale=True, atol=1e-7, rtol=1e-8 @@ -4111,7 +4111,7 @@ def test_center_scale_combinations_pls_2_m_eq_k(self): splits = self.load_Y(["split"]) # Contains 3 splits of different sizes assert Y.shape[1] > 1 assert Y.shape[1] == X.shape[1] - self.check_center_scale_combinations(X, Y, splits, atol=0, rtol=1e-8) + self.check_center_scale_combinations(X, Y, splits, atol=0, rtol=3e-8) # JAX will issue a warning if os.fork() is called as JAX is incompatible with # multi-threaded code. os.fork() is called by the other cross-validation