Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

initial commit of Jax integration [WIP] #23

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 35 additions & 18 deletions fracridge/fracridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
"FracRidgeRegressorCV"]


def _do_svd(X, y, jit=True):
def _do_svd(X, y, mode=1):
"""
Helper function to produce SVD outputs
"""
Expand All @@ -32,7 +32,22 @@ def _do_svd(X, y, jit=True):
# By default, we'll try to use the jit-compiled SVD, which should be
# more performant:
use_scipy = False
if jit:
if mode == 2:
from functools import partial
try:
from jax.numpy.linalg import svd
svd = partial(svd, full_matrices=False)

except ImportError:
warnings.warn("The `mode` key-word argument is set to `2` ",
"but jax could not be imported, or some ",
"dependencies were missing. ",
"compilation failed. Falling back to ",
"`scipy.linalg.svd`")
use_scipy = True
mode = 0

elif mode == 1:
try:
from ._linalg import svd
except ImportError:
Expand All @@ -43,7 +58,7 @@ def _do_svd(X, y, jit=True):
use_scipy = True

# If that doesn't work, or you asked not to, we'll use scipy SVD:
if not jit or use_scipy:
if not mode or use_scipy:
from functools import partial
from scipy.linalg import svd # noqa
svd = partial(svd, full_matrices=False)
Expand All @@ -63,7 +78,7 @@ def _do_svd(X, y, jit=True):
return selt, v_t, ols_coef


def fracridge(X, y, fracs=None, tol=1e-10, jit=True):
def fracridge(X, y, fracs=None, tol=1e-10, mode=1):
"""
Approximates alpha parameters to match desired fractions of OLS length.

Expand All @@ -82,10 +97,10 @@ def fracridge(X, y, fracs=None, tol=1e-10, jit=True):
to be sorted. Otherwise, raises ValueError.
Default: np.arange(.1, 1.1, .1).

jit : bool, optional
mode : int, optional
Whether to speed up computations by using a just-in-time compiled
version of core computations. This may not work well with very large
datasets. Default: True
version of core computations (mode=1), or using jax (mode=2).
Set mode=0 to use `scipy.linalg.svd` instead.
This may not work well with very large datasets. Default: 1

Returns
-------
Expand Down Expand Up @@ -145,7 +160,7 @@ def fracridge(X, y, fracs=None, tol=1e-10, jit=True):
ff = fracs.shape[0]

# Calculate the rotation of the data
selt, v_t, ols_coef = _do_svd(X, y, jit=jit)
selt, v_t, ols_coef = _do_svd(X, y, mode=mode)

# Set solutions for small eigenvalues to 0 for all targets:
isbad = selt < tol
Expand Down Expand Up @@ -229,8 +244,9 @@ class FracRidgeRegressor(BaseEstimator, MultiOutputMixin):
Tolerance under which singular values of the X matrix are considered
to be zero. Default: 1e-10.

jit : bool, optional.
Whether to use jit-accelerated implementation. Default: True.
mode : int, optional.
Whether to use jit-accelerated implementation (mode=1), jax (mode=2),
or revert to scipy (mode=0). Default: 1.

Attributes
----------
Expand Down Expand Up @@ -274,13 +290,13 @@ class FracRidgeRegressor(BaseEstimator, MultiOutputMixin):
0.29
"""
def __init__(self, fracs=None, fit_intercept=False, normalize=False,
copy_X=True, tol=1e-10, jit=True):
copy_X=True, tol=1e-10, mode=1):
self.fracs = fracs
self.fit_intercept = fit_intercept
self.normalize = normalize
self.copy_X = copy_X
self.tol = tol
self.jit = jit
self.mode = mode

def _validate_input(self, X, y, sample_weight=None):
"""
Expand All @@ -306,7 +322,7 @@ def fit(self, X, y, sample_weight=None):
X, y, X_offset, y_offset, X_scale = self._validate_input(
X, y, sample_weight=sample_weight)
coef, alpha = fracridge(X, y, fracs=self.fracs, tol=self.tol,
jit=self.jit)
mode=self.mode)
self.alpha_ = alpha
self.coef_ = coef
self._set_intercept(X_offset, y_offset, X_scale)
Expand Down Expand Up @@ -376,8 +392,9 @@ class FracRidgeRegressorCV(FracRidgeRegressor):
Tolerance under which singular values of the X matrix are considered
to be zero. Default: 1e-10.

jit : bool, optional.
Whether to use jit-accelerated implementation. Default: True.
mode : int, optional.
Whether to use jit-accelerated implementation (mode=1), jax (mode=2),
or revert to scipy (mode=0). Default: 1.

cv : int, cross-validation generator or an iterable
See https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html # noqa
Expand Down Expand Up @@ -415,13 +432,13 @@ class FracRidgeRegressorCV(FracRidgeRegressor):
0.1
"""
def __init__(self, frac_grid=None, fit_intercept=False, normalize=False,
copy_X=True, tol=1e-10, jit=True, cv=None, scoring=None):
copy_X=True, tol=1e-10, mode=1, cv=None, scoring=None):

self.frac_grid = frac_grid
if self.frac_grid is None:
self.frac_grid = np.arange(.1, 1.1, .1)
super().__init__(self, fit_intercept=False, normalize=False,
copy_X=True, tol=tol, jit=True)
copy_X=True, tol=tol, mode=1)
self.cv = cv
self.scoring = scoring

Expand All @@ -436,7 +453,7 @@ def fit(self, X, y, sample_weight=None):
normalize=self.normalize,
copy_X=self.copy_X,
tol=self.tol,
jit=self.jit),
mode=self.mode),
parameters, cv=self.cv, scoring=self.scoring)

gs.fit(X, y, sample_weight=sample_weight)
Expand Down
23 changes: 12 additions & 11 deletions fracridge/tests/test_fracridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,17 @@
import pytest


def run_fracridge(X, y, fracs, jit):
fracridge(X, y, fracs=fracs, jit=jit)
def run_fracridge(X, y, fracs, mode):
fracridge(X, y, fracs=fracs, mode=mode)


@pytest.mark.parametrize("nn, pp", [(1000, 10), (10, 100), (284, 50)])
@pytest.mark.parametrize("bb", [(1), (2), (1000)])
@pytest.mark.parametrize("jit", [True, False])
def test_benchmark_fracridge(nn, pp, bb, jit, benchmark):
@pytest.mark.parametrize("mode", [0, 1, 2])
def test_benchmark_fracridge(nn, pp, bb, mode, benchmark):
X, y, _, _ = make_data(nn, pp, bb)
fracs = np.arange(.1, 1.1, .1)
benchmark(run_fracridge, X, y, fracs, jit)
benchmark(run_fracridge, X, y, fracs, mode)


def make_data(nn, pp, bb, fit_intercept=False):
Expand Down Expand Up @@ -89,11 +89,12 @@ def test_v_fracs(nn, pp, bb, frac):
@pytest.mark.parametrize("nn, pp", [(1000, 10), (10, 100), (284, 50)])
@pytest.mark.parametrize("bb", [(1), (2), (1000)])
@pytest.mark.parametrize("fit_intercept", [True, False])
@pytest.mark.parametrize("jit", [True, False])
def test_FracRidgeRegressor_predict(nn, pp, bb, fit_intercept, jit):
@pytest.mark.parametrize("mode", [0, 1, 2])
def test_FracRidgeRegressor_predict(nn, pp, bb, fit_intercept, mode):
X, y, coef_ols, pred_ols = make_data(nn, pp, bb, fit_intercept)
fracs = np.arange(.1, 1.1, .1)
FR = FracRidgeRegressor(fracs=fracs, fit_intercept=fit_intercept, jit=jit)
FR = FracRidgeRegressor(
fracs=fracs, fit_intercept=fit_intercept, mode=mode)
FR.fit(X, y)
pred_fr = FR.predict(X)
assert np.allclose(pred_fr[:, -1, ...], pred_ols, atol=10e-3)
Expand All @@ -112,12 +113,12 @@ def test_FracRidge_singleton_frac():
@pytest.mark.parametrize("nn, pp", [(1000, 10), (10, 100), (284, 50)])
@pytest.mark.parametrize("bb", [(1), (2), (1000)])
@pytest.mark.parametrize("fit_intercept", [True, False])
@pytest.mark.parametrize("jit", [True, False])
def test_FracRidgeRegressorCV(nn, pp, bb, fit_intercept, jit):
@pytest.mark.parametrize("mode", [0, 1, 2])
def test_FracRidgeRegressorCV(nn, pp, bb, fit_intercept, mode):
X, y, _, _ = make_data(nn, pp, bb, fit_intercept)
fracs = np.arange(.1, 1.1, .1)
FRCV = FracRidgeRegressorCV(frac_grid=fracs, fit_intercept=fit_intercept,
jit=jit)
mode=mode)
FRCV.fit(X, y)
FR = FracRidgeRegressor(fracs=FRCV.best_frac_)
FR.fit(X, y)
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ scikit-learn
numba
setuptools_scm
pillow
jax