Skip to content

Commit

Permalink
feat: added L01Ball operator
Browse files Browse the repository at this point in the history
  • Loading branch information
mrava87 committed Dec 18, 2023
1 parent 330bcea commit 769da32
Show file tree
Hide file tree
Showing 7 changed files with 126 additions and 12 deletions.
2 changes: 2 additions & 0 deletions docs/source/api/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ Orthogonal projections
HyperPlaneBoxProj
IntersectionProj
L0BallProj
L01BallProj
L1BallProj
NuclearBallProj
SimplexProj
Expand Down Expand Up @@ -68,6 +69,7 @@ Convex
Intersection
L0
L0Ball
L01Ball
L1
L1Ball
L2
Expand Down
41 changes: 38 additions & 3 deletions pyproximal/projection/L0.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import numpy as np
from pyproximal.projection import SimplexProj


class L0BallProj():
r"""L0 ball projection.
r""":math:`L_0` ball projection.
Parameters
----------
Expand Down Expand Up @@ -32,4 +31,40 @@ def __call__(self, x):
xshape = x.shape
xf = x.copy().flatten()
xf[np.argsort(np.abs(xf))[:-self.radius]] = 0
return xf.reshape(xshape)
return xf.reshape(xshape)


class L01BallProj():
r""":math:`L_{0,1}` ball projection.
Parameters
----------
radius : :obj:`int`
Radius
Notes
-----
Given an :math:`L_{0,1}` ball defined as:
.. math::
L_{0,1}^{r} =
\{ \mathbf{x}: \text{count}([||\mathbf{x}_1||_1,
||\mathbf{x}_2||_1, ..., ||\mathbf{x}_1||_1] \ne 0) \leq r \}
its orthogonal projection is computed by finding the :math:`r` highest
largest entries of a vector obtained by applying the :math:`L_1` norm to each
column of a matrix :math:`\mathbf{x}` (in absolute value), keeping those
and zero-ing all the other entries.
Note that this is the proximal operator of the corresponding
indicator function :math:`\mathcal{I}_{L_{0,1}^{r}}`.
"""
def __init__(self, radius):
self.radius = int(radius)

def __call__(self, x):
xc = x.copy()
xf = np.linalg.norm(x, axis=0, ord=1)
xc[:, np.argsort(np.abs(xf))[:-self.radius]] = 0
return xc
2 changes: 1 addition & 1 deletion pyproximal/projection/L1.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@


class L1BallProj():
r"""L1 ball projection.
r""":math:`L_1` ball projection.
Parameters
----------
Expand Down
3 changes: 2 additions & 1 deletion pyproximal/projection/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
HyperPlaneBoxProj Projection onto an intersection beween a HyperPlane and a Box
SimplexProj Projection onto a Simplex
L0Proj Projection onto an L0 Ball
L01Proj Projection onto an L0,1 Ball
L1Proj Projection onto an L1 Ball
EuclideanBallProj Projection onto an Euclidean Ball
NuclearBallProj Projection onto a Nuclear Ball
Expand All @@ -29,5 +30,5 @@


__all__ = ['BoxProj', 'HyperPlaneBoxProj', 'SimplexProj', 'L0BallProj',
'L1BallProj', 'EuclideanBallProj', 'NuclearBallProj',
'L01BallProj', 'L1BallProj', 'EuclideanBallProj', 'NuclearBallProj',
'IntersectionProj', 'AffineSetProj', 'HankelProj']
66 changes: 61 additions & 5 deletions pyproximal/proximal/L0.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np

from pyproximal.ProxOperator import _check_tau
from pyproximal.projection import L0BallProj
from pyproximal.projection import L0BallProj, L01BallProj
from pyproximal import ProxOperator
from pyproximal.proximal.L1 import _current_sigma

Expand Down Expand Up @@ -35,7 +35,7 @@ def _hardthreshold(x, thresh):


class L0(ProxOperator):
r"""L0 norm proximal operator.
r""":math:`L_0` norm proximal operator.
Proximal operator of the :math:`\ell_0` norm:
:math:`\sigma\|\mathbf{x}\|_0 = \text{count}(x_i \ne 0)`.
Expand Down Expand Up @@ -92,7 +92,7 @@ def prox(self, x, tau):


class L0Ball(ProxOperator):
r"""L0 ball proximal operator.
r""":math:`L_0` ball proximal operator.
Proximal operator of the L0 ball: :math:`L0_{r} =
\{ \mathbf{x}: ||\mathbf{x}||_0 \leq r \}`.
Expand All @@ -103,7 +103,6 @@ class L0Ball(ProxOperator):
Radius. This can be a constant number or a function that is called passing a
counter which keeps track of how many times the ``prox`` method has been
invoked before and returns a scalar ``radius`` to be used.
Radius
Notes
-----
Expand Down Expand Up @@ -136,4 +135,61 @@ def prox(self, x, tau):
radius = _current_sigma(self.radius, self.count)
self.ball.radius = radius
y = self.ball(x)
return y
return y


class L01Ball(ProxOperator):
r""":math:`L_{0,1}` ball proximal operator.
Proximal operator of the :math:`L_{0,1}` ball: :math:`L_{0,1}^{r} =
\{ \mathbf{x}: \text{count}([||\mathbf{x}_1||_1, ||\mathbf{x}_2||_1, ...,
||\mathbf{x}_1||_1] \ne 0) \leq r \}`
Parameters
----------
ndim : :obj:`int`
Number of dimensions :math:`N_{dim}`. Used to reshape the input array
in a matrix of size :math:`N_{dim} \times N'_{x}` where
:math:`N'_x = \frac{N_x}{N_{dim}}`. Note that the input
vector ``x`` should be created by stacking vectors from different
dimensions.
radius : :obj:`int` or :obj:`func`, optional
Radius. This can be a constant number or a function that is called passing a
counter which keeps track of how many times the ``prox`` method has been
invoked before and returns a scalar ``radius`` to be used.
Notes
-----
As the L0 ball is an indicator function, the proximal operator
corresponds to its orthogonal projection
(see :class:`pyproximal.projection.L01BallProj` for details.
"""
def __init__(self, ndim, radius):
super().__init__(None, False)
self.ndim = ndim
self.radius = radius
self.ball = L01BallProj(self.radius if not callable(radius) else radius(0))
self.count = 0

def __call__(self, x, tol=1e-4):
x = x.reshape(self.ndim, len(x) // self.ndim)
radius = _current_sigma(self.radius, self.count)
return np.linalg.norm(np.linalg.norm(x, ord=1, axis=0), ord=0) <= radius

def _increment_count(func):
"""Increment counter
"""
def wrapped(self, *args, **kwargs):
self.count += 1
return func(self, *args, **kwargs)
return wrapped

@_increment_count
@_check_tau
def prox(self, x, tau):
x = x.reshape(self.ndim, len(x) // self.ndim)
radius = _current_sigma(self.radius, self.count)
self.ball.radius = radius
y = self.ball(x)
return y.ravel()
3 changes: 2 additions & 1 deletion pyproximal/proximal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
Nonlinear Nonlinear function
L0 L0 Norm
L0Ball L0 Ball
L01pBall L0,1 Ball
L1 L1 Norm
L1Ball L1 Ball
Euclidean Euclidean Norm
Expand Down Expand Up @@ -67,7 +68,7 @@
from .Hankel import *

__all__ = ['Box', 'Simplex', 'Intersection', 'AffineSet', 'Quadratic',
'Euclidean', 'EuclideanBall', 'L0', 'L0Ball', 'L1', 'L1Ball', 'L2',
'Euclidean', 'EuclideanBall', 'L0', 'L0Ball', 'L01Ball', 'L1', 'L1Ball', 'L2',
'L2Convolve', 'L21', 'L21_plus_L1', 'Huber', 'TV', 'RelaxedMumfordShah',
'Nuclear', 'NuclearBall', 'Orthogonal', 'VStack', 'Nonlinear', 'SCAD',
'Log', 'Log1', 'ETP', 'Geman', 'QuadraticEnvelopeCard', 'SingularValuePenalty',
Expand Down
21 changes: 20 additions & 1 deletion pytests/test_projection.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from numpy.testing import assert_array_almost_equal
from pylops.basicoperators import Identity
from pyproximal.utils import moreau
from pyproximal.proximal import Box, EuclideanBall, L0Ball, L1Ball, \
from pyproximal.proximal import Box, EuclideanBall, L0Ball, L01Ball, L1Ball, \
NuclearBall, Simplex, AffineSet, Hankel

par1 = {'nx': 10, 'ny': 8, 'axis': 0, 'dtype': 'float32'} # even float32 dir0
Expand Down Expand Up @@ -65,6 +65,25 @@ def test_L0Ball(par):
assert moreau(l0, x, tau)


@pytest.mark.parametrize("par", [(par1), (par2)])
def test_L01Ball(par):
"""L01 Ball projection and proximal/dual proximal of related indicator
"""
np.random.seed(10)

l0 = L01Ball(3, 1)
x = np.random.normal(0., 1., (3, par['nx'])).astype(par['dtype']).ravel() + 1.

# evaluation
assert l0(x) == False
xp = l0.prox(x, 1.)
assert l0(xp) == True

# prox / dualprox
tau = 2.
assert moreau(l0, x, tau)


@pytest.mark.parametrize("par", [(par1), (par2)])
def test_L1Ball(par):
"""L1 Ball projection and proximal/dual proximal of related indicator
Expand Down

0 comments on commit 769da32

Please sign in to comment.