diff --git a/docs/source/api/index.rst b/docs/source/api/index.rst index 79fdf2f..085037a 100755 --- a/docs/source/api/index.rst +++ b/docs/source/api/index.rst @@ -24,6 +24,7 @@ Orthogonal projections HyperPlaneBoxProj IntersectionProj L0BallProj + L01BallProj L1BallProj NuclearBallProj SimplexProj @@ -68,6 +69,7 @@ Convex Intersection L0 L0Ball + L01Ball L1 L1Ball L2 diff --git a/pyproximal/projection/L0.py b/pyproximal/projection/L0.py index 736f511..7ad057f 100644 --- a/pyproximal/projection/L0.py +++ b/pyproximal/projection/L0.py @@ -1,9 +1,8 @@ import numpy as np -from pyproximal.projection import SimplexProj class L0BallProj(): - r"""L0 ball projection. + r""":math:`L_0` ball projection. Parameters ---------- @@ -32,4 +31,40 @@ def __call__(self, x): xshape = x.shape xf = x.copy().flatten() xf[np.argsort(np.abs(xf))[:-self.radius]] = 0 - return xf.reshape(xshape) \ No newline at end of file + return xf.reshape(xshape) + + +class L01BallProj(): + r""":math:`L_{0,1}` ball projection. + + Parameters + ---------- + radius : :obj:`int` + Radius + + Notes + ----- + Given an :math:`L_{0,1}` ball defined as: + + .. math:: + + L_{0,1}^{r} = + \{ \mathbf{x}: \text{count}([||\mathbf{x}_1||_1, + ||\mathbf{x}_2||_1, ..., ||\mathbf{x}_1||_1] \ne 0) \leq r \} + + its orthogonal projection is computed by finding the :math:`r` highest + largest entries of a vector obtained by applying the :math:`L_1` norm to each + column of a matrix :math:`\mathbf{x}` (in absolute value), keeping those + and zero-ing all the other entries. + Note that this is the proximal operator of the corresponding + indicator function :math:`\mathcal{I}_{L_{0,1}^{r}}`. + + """ + def __init__(self, radius): + self.radius = int(radius) + + def __call__(self, x): + xc = x.copy() + xf = np.linalg.norm(x, axis=0, ord=1) + xc[:, np.argsort(np.abs(xf))[:-self.radius]] = 0 + return xc \ No newline at end of file diff --git a/pyproximal/projection/L1.py b/pyproximal/projection/L1.py index aeccde7..bde93ae 100644 --- a/pyproximal/projection/L1.py +++ b/pyproximal/projection/L1.py @@ -3,7 +3,7 @@ class L1BallProj(): - r"""L1 ball projection. + r""":math:`L_1` ball projection. Parameters ---------- diff --git a/pyproximal/projection/__init__.py b/pyproximal/projection/__init__.py index 349994d..a64b489 100644 --- a/pyproximal/projection/__init__.py +++ b/pyproximal/projection/__init__.py @@ -8,6 +8,7 @@ HyperPlaneBoxProj Projection onto an intersection beween a HyperPlane and a Box SimplexProj Projection onto a Simplex L0Proj Projection onto an L0 Ball + L01Proj Projection onto an L0,1 Ball L1Proj Projection onto an L1 Ball EuclideanBallProj Projection onto an Euclidean Ball NuclearBallProj Projection onto a Nuclear Ball @@ -29,5 +30,5 @@ __all__ = ['BoxProj', 'HyperPlaneBoxProj', 'SimplexProj', 'L0BallProj', - 'L1BallProj', 'EuclideanBallProj', 'NuclearBallProj', + 'L01BallProj', 'L1BallProj', 'EuclideanBallProj', 'NuclearBallProj', 'IntersectionProj', 'AffineSetProj', 'HankelProj'] \ No newline at end of file diff --git a/pyproximal/proximal/L0.py b/pyproximal/proximal/L0.py index 2b54eec..1c1052b 100644 --- a/pyproximal/proximal/L0.py +++ b/pyproximal/proximal/L0.py @@ -1,7 +1,7 @@ import numpy as np from pyproximal.ProxOperator import _check_tau -from pyproximal.projection import L0BallProj +from pyproximal.projection import L0BallProj, L01BallProj from pyproximal import ProxOperator from pyproximal.proximal.L1 import _current_sigma @@ -35,7 +35,7 @@ def _hardthreshold(x, thresh): class L0(ProxOperator): - r"""L0 norm proximal operator. + r""":math:`L_0` norm proximal operator. Proximal operator of the :math:`\ell_0` norm: :math:`\sigma\|\mathbf{x}\|_0 = \text{count}(x_i \ne 0)`. @@ -92,7 +92,7 @@ def prox(self, x, tau): class L0Ball(ProxOperator): - r"""L0 ball proximal operator. + r""":math:`L_0` ball proximal operator. Proximal operator of the L0 ball: :math:`L0_{r} = \{ \mathbf{x}: ||\mathbf{x}||_0 \leq r \}`. @@ -103,7 +103,6 @@ class L0Ball(ProxOperator): Radius. This can be a constant number or a function that is called passing a counter which keeps track of how many times the ``prox`` method has been invoked before and returns a scalar ``radius`` to be used. - Radius Notes ----- @@ -136,4 +135,61 @@ def prox(self, x, tau): radius = _current_sigma(self.radius, self.count) self.ball.radius = radius y = self.ball(x) - return y \ No newline at end of file + return y + + +class L01Ball(ProxOperator): + r""":math:`L_{0,1}` ball proximal operator. + + Proximal operator of the :math:`L_{0,1}` ball: :math:`L_{0,1}^{r} = + \{ \mathbf{x}: \text{count}([||\mathbf{x}_1||_1, ||\mathbf{x}_2||_1, ..., + ||\mathbf{x}_1||_1] \ne 0) \leq r \}` + + Parameters + ---------- + ndim : :obj:`int` + Number of dimensions :math:`N_{dim}`. Used to reshape the input array + in a matrix of size :math:`N_{dim} \times N'_{x}` where + :math:`N'_x = \frac{N_x}{N_{dim}}`. Note that the input + vector ``x`` should be created by stacking vectors from different + dimensions. + radius : :obj:`int` or :obj:`func`, optional + Radius. This can be a constant number or a function that is called passing a + counter which keeps track of how many times the ``prox`` method has been + invoked before and returns a scalar ``radius`` to be used. + + Notes + ----- + As the L0 ball is an indicator function, the proximal operator + corresponds to its orthogonal projection + (see :class:`pyproximal.projection.L01BallProj` for details. + + """ + def __init__(self, ndim, radius): + super().__init__(None, False) + self.ndim = ndim + self.radius = radius + self.ball = L01BallProj(self.radius if not callable(radius) else radius(0)) + self.count = 0 + + def __call__(self, x, tol=1e-4): + x = x.reshape(self.ndim, len(x) // self.ndim) + radius = _current_sigma(self.radius, self.count) + return np.linalg.norm(np.linalg.norm(x, ord=1, axis=0), ord=0) <= radius + + def _increment_count(func): + """Increment counter + """ + def wrapped(self, *args, **kwargs): + self.count += 1 + return func(self, *args, **kwargs) + return wrapped + + @_increment_count + @_check_tau + def prox(self, x, tau): + x = x.reshape(self.ndim, len(x) // self.ndim) + radius = _current_sigma(self.radius, self.count) + self.ball.radius = radius + y = self.ball(x) + return y.ravel() \ No newline at end of file diff --git a/pyproximal/proximal/__init__.py b/pyproximal/proximal/__init__.py index 6d64f2a..26ef53f 100644 --- a/pyproximal/proximal/__init__.py +++ b/pyproximal/proximal/__init__.py @@ -12,6 +12,7 @@ Nonlinear Nonlinear function L0 L0 Norm L0Ball L0 Ball + L01pBall L0,1 Ball L1 L1 Norm L1Ball L1 Ball Euclidean Euclidean Norm @@ -67,7 +68,7 @@ from .Hankel import * __all__ = ['Box', 'Simplex', 'Intersection', 'AffineSet', 'Quadratic', - 'Euclidean', 'EuclideanBall', 'L0', 'L0Ball', 'L1', 'L1Ball', 'L2', + 'Euclidean', 'EuclideanBall', 'L0', 'L0Ball', 'L01Ball', 'L1', 'L1Ball', 'L2', 'L2Convolve', 'L21', 'L21_plus_L1', 'Huber', 'TV', 'RelaxedMumfordShah', 'Nuclear', 'NuclearBall', 'Orthogonal', 'VStack', 'Nonlinear', 'SCAD', 'Log', 'Log1', 'ETP', 'Geman', 'QuadraticEnvelopeCard', 'SingularValuePenalty', diff --git a/pytests/test_projection.py b/pytests/test_projection.py index e68ac93..e2d4f81 100644 --- a/pytests/test_projection.py +++ b/pytests/test_projection.py @@ -4,7 +4,7 @@ from numpy.testing import assert_array_almost_equal from pylops.basicoperators import Identity from pyproximal.utils import moreau -from pyproximal.proximal import Box, EuclideanBall, L0Ball, L1Ball, \ +from pyproximal.proximal import Box, EuclideanBall, L0Ball, L01Ball, L1Ball, \ NuclearBall, Simplex, AffineSet, Hankel par1 = {'nx': 10, 'ny': 8, 'axis': 0, 'dtype': 'float32'} # even float32 dir0 @@ -65,6 +65,25 @@ def test_L0Ball(par): assert moreau(l0, x, tau) +@pytest.mark.parametrize("par", [(par1), (par2)]) +def test_L01Ball(par): + """L01 Ball projection and proximal/dual proximal of related indicator + """ + np.random.seed(10) + + l0 = L01Ball(3, 1) + x = np.random.normal(0., 1., (3, par['nx'])).astype(par['dtype']).ravel() + 1. + + # evaluation + assert l0(x) == False + xp = l0.prox(x, 1.) + assert l0(xp) == True + + # prox / dualprox + tau = 2. + assert moreau(l0, x, tau) + + @pytest.mark.parametrize("par", [(par1), (par2)]) def test_L1Ball(par): """L1 Ball projection and proximal/dual proximal of related indicator diff --git a/tutorials/groupsparsity.py b/tutorials/groupsparsity.py new file mode 100644 index 0000000..70822bd --- /dev/null +++ b/tutorials/groupsparsity.py @@ -0,0 +1,160 @@ +r""" +Group sparsity +============== +This notebooks considers the problem of jointly interpolating N (e.g., 2) signals +with sparse representation in the frequency domain and shows the importance of applying +a group sparsity constraint by means of the :class:`pyproximal.proximal.L01Ball` proximal operator. + +Given the following problem: + +.. math:: + [\mathbf{y}_1^T, \mathbf{y}_2^T,\ldots,\mathbf{y}_N^T]^T = + diag(\mathbf{R}, \mathbf{R}, ..., \mathbf{R}) + [\mathbf{x}_1^T, \mathbf{x}_2^T, \ldots,\mathbf{x}_N^T]^T \rightarrow \mathbf{y}=\mathbf{R}_N\mathbf{x} , + +we aim to find a solution to this objective function: + +.. math:: + J = \frac{1}{2} ||\mathbf{y} - \mathbf{R}_N \mathbf{x}||_2^2 \; s.t. ||\mathbf{X}||_{0,1} < K + + +where :math:`\mathbf{X}` is a matrix whose rows are represented by the different +signals :math:`\mathbf{x}_i`, and the :math:`L_{0,1}` norm computes the number of non-zero elements of +a vector whose elements are the $L_1$ norm of each column of :math:`\mathbf{X}`. + +""" +import numpy as np +import matplotlib.pyplot as plt +import pylops + +import pyproximal + +plt.close('all') +np.random.seed(10) + +############################################################################### +# Let's first create 2 signals in the frequency domain composed by the +# superposition of 3 sinusoids with different frequencies. +ifreqs = [4, 8, 11] +amps1 = [1.0, 0.2, 0.5] +amps2 = [3.0, 3.0, 2.0] + +N = 2 ** 8 +nfft = N +dt = 0.004 +t = np.arange(N) * dt +f = np.fft.rfftfreq(nfft, dt) + +FFTop = 10 * pylops.signalprocessing.FFT(N, nfft=nfft, real=True) + +X1 = np.zeros(nfft // 2 + 1, dtype='complex128') +X2 = np.zeros(nfft // 2 + 1, dtype='complex128') +X1[ifreqs] = amps1 +X2[ifreqs] = amps2 + +x1 = FFTop.H * X1 +x2 = FFTop.H * X2 + +fig, axs = plt.subplots(2, 1, figsize=(12, 8)) +axs[0].plot(f, np.abs(X1), 'k', lw=2) +axs[0].plot(f, np.abs(X2), 'r', lw=2) +axs[0].set_xlim(0, 30) +axs[0].set_title('Data (frequency domain)') +axs[1].plot(t, x1, 'k', lw=2) +axs[1].plot(t, x2, 'r', lw=2) +axs[1].set_title('Data (time domain)') +axs[1].axis('tight') +plt.tight_layout() + +############################################################################### +# We now define the locations at which the signals will be sampled. The first +# signal is severely subsampled (10% of available samples), whilst the second +# dataset retains 60% of its samples. This choice is made on purpose to see +# if group sparsity could help interpolating the first signal by leveraging +# the fact that it is easier to interpolate the second signal +np.random.seed(10) + +perc_subsampling = (0.1, 0.6) +Nsub1, Nsub2 = int(np.round(N * perc_subsampling[0])), int(np.round(N * perc_subsampling[1])) +iava1 = np.sort(np.random.permutation(np.arange(N))[:Nsub1]) +iava2 = np.sort(np.random.permutation(np.arange(N))[:Nsub2]) + +# Create restriction operator +Rop1 = pylops.Restriction(N, iava1, dtype='float64') +Rop2 = pylops.Restriction(N, iava2, dtype='float64') + +y1 = Rop1 * x1 +y2 = Rop2 * x2 + +Op1 = Rop1 * FFTop.H +Op2 = Rop2 * FFTop.H + +X1adj = Op1.H * y1 +X2adj = Op2.H * y2 + +############################################################################### +# Let's try to interpolate the first signal +L = np.abs((Op1.H * Op1).eigs(1)[0]) +eps = 1 # not used given that a projection is used as regularizer +niter = 400 +tau = 0.95 / L + +l0 = pyproximal.proximal.L0Ball(3) +l2 = pyproximal.proximal.L2(Op=Op1, b=y1) +X1est = pyproximal.optimization.primal.ProximalGradient( + l2, l0, tau=tau, x0=np.zeros(nfft // 2 + 1, dtype='complex128'), + epsg=eps, niter=niter, acceleration='fista', show=False) +x1est = FFTop.H * X1est + +fig, axs = plt.subplots(1, 2, sharey=True, figsize=(12, 3)) +axs[0].plot(np.abs(X1), 'k', lw=4, label='Original') +axs[0].plot(np.abs(X1est), '--b', lw=2, label='Rec') +axs[0].set_title('Data (frequency domain)') +axs[0].set_xlim(0, 30) +axs[1].plot(t, x1, 'k', lw=4, label='Original') +axs[1].plot(t, x1est, '--b', lw=2, label='Rec') +axs[1].set_title('Data (time domain)') +axs[1].legend() +plt.tight_layout() + +############################################################################### +# And now we interpolate the two signals together +Opp = pylops.BlockDiag([Op1, Op2]) +yy = np.hstack([y1, y2]) + +L = np.abs((Opp.H * Opp).eigs(1)[0]) +eps = 1 # not used given that a projection is used as regularizer +niter = 400 +tau= 0.99 / L + +l0 = pyproximal.proximal.L01Ball(ndim=2, radius=4) +l2 = pyproximal.proximal.L2(Op=Opp, b=yy) + +XXest = pyproximal.optimization.primal.ProximalGradient( + l2, l0, tau=tau, x0=np.zeros(2*(nfft // 2 + 1), dtype='complex128'), + epsg=eps, niter=niter, acceleration='fista', show=False) + +X1est, X2est = XXest[:FFTop.shape[0]], XXest[FFTop.shape[0]:] +x1est = FFTop.H * X1est +x2est = FFTop.H * X2est + +fig, axs = plt.subplots(1, 2, sharey=True, figsize=(14, 3)) +axs[0].plot(np.abs(X1), 'k', lw=4, label='Original') +axs[0].plot(np.abs(X1est), '--b', lw=2, label='Rec') +axs[0].set_title('First data') +axs[1].plot(np.abs(X2), 'k', lw=4) +axs[1].plot(np.abs(X2est), '--b', lw=2) +axs[0].set_xlim(0, 30) +axs[1].set_xlim(0, 30) +plt.tight_layout() + +fig, axs = plt.subplots(1, 2, sharey=True, figsize=(14, 3)) +axs[0].plot(t, x1, 'k', lw=4, label='Original') +axs[0].plot(t, x1est, '--b', lw=2, label='Rec') +axs[0].set_title('First data') +axs[0].legend() +axs[1].plot(t, x2, 'k', lw=4) +axs[1].plot(t, x2est, '--b', lw=2) +axs[1].set_title('Second data') +plt.tight_layout() +