diff --git a/fracridge/fracridge.py b/fracridge/fracridge.py index 5f50371..c0edd58 100644 --- a/fracridge/fracridge.py +++ b/fracridge/fracridge.py @@ -22,7 +22,7 @@ "FracRidgeRegressorCV"] -def _do_svd(X, y, jit=True): +def _do_svd(X, y, mode=1): """ Helper function to produce SVD outputs """ @@ -32,7 +32,22 @@ def _do_svd(X, y, jit=True): # Per default, we'll try to use the jit-compiled SVD, which should be # more performant: use_scipy = False - if jit: + if mode == 2: + from functools import partial + try: + from jax.numpy.linalg import svd + svd = partial(svd, full_matrices=False) + + except ImportError: + warnings.warn("The `mode` key-word argument is set to `2` ", + "but jax could not be imported, or some ", + "dependencies were missing. ", + "compilation failed. Falling back to ", + "`scipy.linalg.svd`") + use_scipy = True + mode = 0 + + elif mode == 1: try: from ._linalg import svd except ImportError: @@ -43,7 +58,7 @@ def _do_svd(X, y, jit=True): use_scipy = True # If that doesn't work, or you asked not to, we'll use scipy SVD: - if not jit or use_scipy: + if not mode or use_scipy: from functools import partial from scipy.linalg import svd # noqa svd = partial(svd, full_matrices=False) @@ -63,7 +78,7 @@ def _do_svd(X, y, jit=True): return selt, v_t, ols_coef -def fracridge(X, y, fracs=None, tol=1e-10, jit=True): +def fracridge(X, y, fracs=None, tol=1e-10, mode=1): """ Approximates alpha parameters to match desired fractions of OLS length. @@ -82,10 +97,10 @@ def fracridge(X, y, fracs=None, tol=1e-10, jit=True): to be sorted. Otherwise, raises ValueError. Default: np.arange(.1, 1.1, .1). - jit : bool, optional + mode : int, optional Whether to speed up computations by using a just-in-time compiled - version of core computations. This may not work well with very large - datasets. Default: True + version of core computations (mode=1), or using jax (mode=2). + This may not work well with very large datasets. Default: 1 Returns ------- @@ -145,7 +160,7 @@ def fracridge(X, y, fracs=None, tol=1e-10, jit=True): ff = fracs.shape[0] # Calculate the rotation of the data - selt, v_t, ols_coef = _do_svd(X, y, jit=jit) + selt, v_t, ols_coef = _do_svd(X, y, mode=mode) # Set solutions for small eigenvalues to 0 for all targets: isbad = selt < tol @@ -229,8 +244,9 @@ class FracRidgeRegressor(BaseEstimator, MultiOutputMixin): Tolerance under which singular values of the X matrix are considered to be zero. Default: 1e-10. - jit : bool, optional. - Whether to use jit-accelerated implementation. Default: True. + mode : int, optional. + Whether to use jit-accelerated implementation (mode=1), jax (mode=2), + or revert to scipy (mode=0). Default: 1. Attributes ---------- @@ -274,13 +290,13 @@ class FracRidgeRegressor(BaseEstimator, MultiOutputMixin): 0.29 """ def __init__(self, fracs=None, fit_intercept=False, normalize=False, - copy_X=True, tol=1e-10, jit=True): + copy_X=True, tol=1e-10, mode=1): self.fracs = fracs self.fit_intercept = fit_intercept self.normalize = normalize self.copy_X = copy_X self.tol = tol - self.jit = jit + self.mode = mode def _validate_input(self, X, y, sample_weight=None): """ @@ -306,7 +322,7 @@ def fit(self, X, y, sample_weight=None): X, y, X_offset, y_offset, X_scale = self._validate_input( X, y, sample_weight=sample_weight) coef, alpha = fracridge(X, y, fracs=self.fracs, tol=self.tol, - jit=self.jit) + mode=self.mode) self.alpha_ = alpha self.coef_ = coef self._set_intercept(X_offset, y_offset, X_scale) @@ -376,8 +392,9 @@ class FracRidgeRegressorCV(FracRidgeRegressor): Tolerance under which singular values of the X matrix are considered to be zero. Default: 1e-10. - jit : bool, optional. - Whether to use jit-accelerated implementation. Default: True. + mode : bool, optional. + Whether to use jit-accelerated implementation (mode=1), or + jax (mode=2), otherwise revert to scipy. Default: 1. cv : int, cross-validation generator or an iterable See https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html # noqa @@ -415,13 +432,13 @@ class FracRidgeRegressorCV(FracRidgeRegressor): 0.1 """ def __init__(self, frac_grid=None, fit_intercept=False, normalize=False, - copy_X=True, tol=1e-10, jit=True, cv=None, scoring=None): + copy_X=True, tol=1e-10, mode=1, cv=None, scoring=None): self.frac_grid = frac_grid if self.frac_grid is None: self.frac_grid = np.arange(.1, 1.1, .1) super().__init__(self, fit_intercept=False, normalize=False, - copy_X=True, tol=tol, jit=True) + copy_X=True, tol=tol, mode=1) self.cv = cv self.scoring = scoring @@ -436,7 +453,7 @@ def fit(self, X, y, sample_weight=None): normalize=self.normalize, copy_X=self.copy_X, tol=self.tol, - jit=self.jit), + mode=self.mode), parameters, cv=self.cv, scoring=self.scoring) gs.fit(X, y, sample_weight=sample_weight) diff --git a/fracridge/tests/test_fracridge.py b/fracridge/tests/test_fracridge.py index 01a4b3f..6e66fbc 100644 --- a/fracridge/tests/test_fracridge.py +++ b/fracridge/tests/test_fracridge.py @@ -6,17 +6,17 @@ import pytest -def run_fracridge(X, y, fracs, jit): - fracridge(X, y, fracs=fracs, jit=jit) +def run_fracridge(X, y, fracs, mode): + fracridge(X, y, fracs=fracs, mode=mode) @pytest.mark.parametrize("nn, pp", [(1000, 10), (10, 100), (284, 50)]) @pytest.mark.parametrize("bb", [(1), (2), (1000)]) -@pytest.mark.parametrize("jit", [True, False]) -def test_benchmark_fracridge(nn, pp, bb, jit, benchmark): +@pytest.mark.parametrize("mode", [0, 1, 2]) +def test_benchmark_fracridge(nn, pp, bb, mode, benchmark): X, y, _, _ = make_data(nn, pp, bb) fracs = np.arange(.1, 1.1, .1) - benchmark(run_fracridge, X, y, fracs, jit) + benchmark(run_fracridge, X, y, fracs, mode) def make_data(nn, pp, bb, fit_intercept=False): @@ -89,11 +89,12 @@ def test_v_fracs(nn, pp, bb, frac): @pytest.mark.parametrize("nn, pp", [(1000, 10), (10, 100), (284, 50)]) @pytest.mark.parametrize("bb", [(1), (2), (1000)]) @pytest.mark.parametrize("fit_intercept", [True, False]) -@pytest.mark.parametrize("jit", [True, False]) -def test_FracRidgeRegressor_predict(nn, pp, bb, fit_intercept, jit): +@pytest.mark.parametrize("mode", [0, 1, 2]) +def test_FracRidgeRegressor_predict(nn, pp, bb, fit_intercept, mode): X, y, coef_ols, pred_ols = make_data(nn, pp, bb, fit_intercept) fracs = np.arange(.1, 1.1, .1) - FR = FracRidgeRegressor(fracs=fracs, fit_intercept=fit_intercept, jit=jit) + FR = FracRidgeRegressor( + fracs=fracs, fit_intercept=fit_intercept, mode=mode) FR.fit(X, y) pred_fr = FR.predict(X) assert np.allclose(pred_fr[:, -1, ...], pred_ols, atol=10e-3) @@ -112,12 +113,12 @@ def test_FracRidge_singleton_frac(): @pytest.mark.parametrize("nn, pp", [(1000, 10), (10, 100), (284, 50)]) @pytest.mark.parametrize("bb", [(1), (2), (1000)]) @pytest.mark.parametrize("fit_intercept", [True, False]) -@pytest.mark.parametrize("jit", [True, False]) -def test_FracRidgeRegressorCV(nn, pp, bb, fit_intercept, jit): +@pytest.mark.parametrize("mode", [0, 1, 2]) +def test_FracRidgeRegressorCV(nn, pp, bb, fit_intercept, mode): X, y, _, _ = make_data(nn, pp, bb, fit_intercept) fracs = np.arange(.1, 1.1, .1) FRCV = FracRidgeRegressorCV(frac_grid=fracs, fit_intercept=fit_intercept, - jit=jit) + mode=mode) FRCV.fit(X, y) FR = FracRidgeRegressor(fracs=FRCV.best_frac_) FR.fit(X, y) diff --git a/requirements.txt b/requirements.txt index f8551b5..aacc05e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,3 +2,4 @@ scikit-learn numba setuptools_scm pillow +jax