diff --git a/docs/source/api/index.rst b/docs/source/api/index.rst index 79fdf2f..085037a 100755 --- a/docs/source/api/index.rst +++ b/docs/source/api/index.rst @@ -24,6 +24,7 @@ Orthogonal projections HyperPlaneBoxProj IntersectionProj L0BallProj + L01BallProj L1BallProj NuclearBallProj SimplexProj @@ -68,6 +69,7 @@ Convex Intersection L0 L0Ball + L01Ball L1 L1Ball L2 diff --git a/pyproximal/projection/L0.py b/pyproximal/projection/L0.py index 736f511..7ad057f 100644 --- a/pyproximal/projection/L0.py +++ b/pyproximal/projection/L0.py @@ -1,9 +1,8 @@ import numpy as np -from pyproximal.projection import SimplexProj class L0BallProj(): - r"""L0 ball projection. + r""":math:`L_0` ball projection. Parameters ---------- @@ -32,4 +31,40 @@ def __call__(self, x): xshape = x.shape xf = x.copy().flatten() xf[np.argsort(np.abs(xf))[:-self.radius]] = 0 - return xf.reshape(xshape) \ No newline at end of file + return xf.reshape(xshape) + + +class L01BallProj(): + r""":math:`L_{0,1}` ball projection. + + Parameters + ---------- + radius : :obj:`int` + Radius + + Notes + ----- + Given an :math:`L_{0,1}` ball defined as: + + .. math:: + + L_{0,1}^{r} = + \{ \mathbf{x}: \text{count}([||\mathbf{x}_1||_1, + ||\mathbf{x}_2||_1, ..., ||\mathbf{x}_1||_1] \ne 0) \leq r \} + + its orthogonal projection is computed by finding the :math:`r` highest + largest entries of a vector obtained by applying the :math:`L_1` norm to each + column of a matrix :math:`\mathbf{x}` (in absolute value), keeping those + and zero-ing all the other entries. + Note that this is the proximal operator of the corresponding + indicator function :math:`\mathcal{I}_{L_{0,1}^{r}}`. + + """ + def __init__(self, radius): + self.radius = int(radius) + + def __call__(self, x): + xc = x.copy() + xf = np.linalg.norm(x, axis=0, ord=1) + xc[:, np.argsort(np.abs(xf))[:-self.radius]] = 0 + return xc \ No newline at end of file diff --git a/pyproximal/projection/L1.py b/pyproximal/projection/L1.py index aeccde7..bde93ae 100644 --- a/pyproximal/projection/L1.py +++ b/pyproximal/projection/L1.py @@ -3,7 +3,7 @@ class L1BallProj(): - r"""L1 ball projection. + r""":math:`L_1` ball projection. Parameters ---------- diff --git a/pyproximal/projection/__init__.py b/pyproximal/projection/__init__.py index 349994d..a64b489 100644 --- a/pyproximal/projection/__init__.py +++ b/pyproximal/projection/__init__.py @@ -8,6 +8,7 @@ HyperPlaneBoxProj Projection onto an intersection beween a HyperPlane and a Box SimplexProj Projection onto a Simplex L0Proj Projection onto an L0 Ball + L01Proj Projection onto an L0,1 Ball L1Proj Projection onto an L1 Ball EuclideanBallProj Projection onto an Euclidean Ball NuclearBallProj Projection onto a Nuclear Ball @@ -29,5 +30,5 @@ __all__ = ['BoxProj', 'HyperPlaneBoxProj', 'SimplexProj', 'L0BallProj', - 'L1BallProj', 'EuclideanBallProj', 'NuclearBallProj', + 'L01BallProj', 'L1BallProj', 'EuclideanBallProj', 'NuclearBallProj', 'IntersectionProj', 'AffineSetProj', 'HankelProj'] \ No newline at end of file diff --git a/pyproximal/proximal/L0.py b/pyproximal/proximal/L0.py index 2b54eec..1c1052b 100644 --- a/pyproximal/proximal/L0.py +++ b/pyproximal/proximal/L0.py @@ -1,7 +1,7 @@ import numpy as np from pyproximal.ProxOperator import _check_tau -from pyproximal.projection import L0BallProj +from pyproximal.projection import L0BallProj, L01BallProj from pyproximal import ProxOperator from pyproximal.proximal.L1 import _current_sigma @@ -35,7 +35,7 @@ def _hardthreshold(x, thresh): class L0(ProxOperator): - r"""L0 norm proximal operator. + r""":math:`L_0` norm proximal operator. Proximal operator of the :math:`\ell_0` norm: :math:`\sigma\|\mathbf{x}\|_0 = \text{count}(x_i \ne 0)`. @@ -92,7 +92,7 @@ def prox(self, x, tau): class L0Ball(ProxOperator): - r"""L0 ball proximal operator. + r""":math:`L_0` ball proximal operator. Proximal operator of the L0 ball: :math:`L0_{r} = \{ \mathbf{x}: ||\mathbf{x}||_0 \leq r \}`. @@ -103,7 +103,6 @@ class L0Ball(ProxOperator): Radius. This can be a constant number or a function that is called passing a counter which keeps track of how many times the ``prox`` method has been invoked before and returns a scalar ``radius`` to be used. - Radius Notes ----- @@ -136,4 +135,61 @@ def prox(self, x, tau): radius = _current_sigma(self.radius, self.count) self.ball.radius = radius y = self.ball(x) - return y \ No newline at end of file + return y + + +class L01Ball(ProxOperator): + r""":math:`L_{0,1}` ball proximal operator. + + Proximal operator of the :math:`L_{0,1}` ball: :math:`L_{0,1}^{r} = + \{ \mathbf{x}: \text{count}([||\mathbf{x}_1||_1, ||\mathbf{x}_2||_1, ..., + ||\mathbf{x}_1||_1] \ne 0) \leq r \}` + + Parameters + ---------- + ndim : :obj:`int` + Number of dimensions :math:`N_{dim}`. Used to reshape the input array + in a matrix of size :math:`N_{dim} \times N'_{x}` where + :math:`N'_x = \frac{N_x}{N_{dim}}`. Note that the input + vector ``x`` should be created by stacking vectors from different + dimensions. + radius : :obj:`int` or :obj:`func`, optional + Radius. This can be a constant number or a function that is called passing a + counter which keeps track of how many times the ``prox`` method has been + invoked before and returns a scalar ``radius`` to be used. + + Notes + ----- + As the L0 ball is an indicator function, the proximal operator + corresponds to its orthogonal projection + (see :class:`pyproximal.projection.L01BallProj` for details. + + """ + def __init__(self, ndim, radius): + super().__init__(None, False) + self.ndim = ndim + self.radius = radius + self.ball = L01BallProj(self.radius if not callable(radius) else radius(0)) + self.count = 0 + + def __call__(self, x, tol=1e-4): + x = x.reshape(self.ndim, len(x) // self.ndim) + radius = _current_sigma(self.radius, self.count) + return np.linalg.norm(np.linalg.norm(x, ord=1, axis=0), ord=0) <= radius + + def _increment_count(func): + """Increment counter + """ + def wrapped(self, *args, **kwargs): + self.count += 1 + return func(self, *args, **kwargs) + return wrapped + + @_increment_count + @_check_tau + def prox(self, x, tau): + x = x.reshape(self.ndim, len(x) // self.ndim) + radius = _current_sigma(self.radius, self.count) + self.ball.radius = radius + y = self.ball(x) + return y.ravel() \ No newline at end of file diff --git a/pyproximal/proximal/__init__.py b/pyproximal/proximal/__init__.py index 6d64f2a..26ef53f 100644 --- a/pyproximal/proximal/__init__.py +++ b/pyproximal/proximal/__init__.py @@ -12,6 +12,7 @@ Nonlinear Nonlinear function L0 L0 Norm L0Ball L0 Ball + L01pBall L0,1 Ball L1 L1 Norm L1Ball L1 Ball Euclidean Euclidean Norm @@ -67,7 +68,7 @@ from .Hankel import * __all__ = ['Box', 'Simplex', 'Intersection', 'AffineSet', 'Quadratic', - 'Euclidean', 'EuclideanBall', 'L0', 'L0Ball', 'L1', 'L1Ball', 'L2', + 'Euclidean', 'EuclideanBall', 'L0', 'L0Ball', 'L01Ball', 'L1', 'L1Ball', 'L2', 'L2Convolve', 'L21', 'L21_plus_L1', 'Huber', 'TV', 'RelaxedMumfordShah', 'Nuclear', 'NuclearBall', 'Orthogonal', 'VStack', 'Nonlinear', 'SCAD', 'Log', 'Log1', 'ETP', 'Geman', 'QuadraticEnvelopeCard', 'SingularValuePenalty', diff --git a/pytests/test_projection.py b/pytests/test_projection.py index e68ac93..e2d4f81 100644 --- a/pytests/test_projection.py +++ b/pytests/test_projection.py @@ -4,7 +4,7 @@ from numpy.testing import assert_array_almost_equal from pylops.basicoperators import Identity from pyproximal.utils import moreau -from pyproximal.proximal import Box, EuclideanBall, L0Ball, L1Ball, \ +from pyproximal.proximal import Box, EuclideanBall, L0Ball, L01Ball, L1Ball, \ NuclearBall, Simplex, AffineSet, Hankel par1 = {'nx': 10, 'ny': 8, 'axis': 0, 'dtype': 'float32'} # even float32 dir0 @@ -65,6 +65,25 @@ def test_L0Ball(par): assert moreau(l0, x, tau) +@pytest.mark.parametrize("par", [(par1), (par2)]) +def test_L01Ball(par): + """L01 Ball projection and proximal/dual proximal of related indicator + """ + np.random.seed(10) + + l0 = L01Ball(3, 1) + x = np.random.normal(0., 1., (3, par['nx'])).astype(par['dtype']).ravel() + 1. + + # evaluation + assert l0(x) == False + xp = l0.prox(x, 1.) + assert l0(xp) == True + + # prox / dualprox + tau = 2. + assert moreau(l0, x, tau) + + @pytest.mark.parametrize("par", [(par1), (par2)]) def test_L1Ball(par): """L1 Ball projection and proximal/dual proximal of related indicator