Skip to content

Commit

Permalink
Merge pull request #345 from Algue-Rythme:eqqppytreerefine
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 488297091
  • Loading branch information
JAXopt authors committed Nov 14, 2022
2 parents 8fe250d + 62d7eb8 commit 3828913
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 4 deletions.
6 changes: 4 additions & 2 deletions jaxopt/_src/eq_qp.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

from jaxopt._src import base
from jaxopt._src import implicit_diff as idf
from jaxopt._src.tree_util import tree_add, tree_sub
from jaxopt._src.tree_util import tree_add, tree_sub, tree_add_scalar_mul
from jaxopt._src.tree_util import tree_vdot, tree_negative, tree_l2_norm
from jaxopt._src.linear_operator import _make_linear_operator
from jaxopt._src.cvxpy_wrapper import _check_params
Expand Down Expand Up @@ -143,7 +143,9 @@ def matvec_qp(_, x):
def matvec_regularized_qp(_, x):
primal, dual_eq = x
Stop, Sbottom = matvec(x)
return Stop + ridge * primal, Sbottom - ridge * dual_eq
a = tree_add_scalar_mul(Stop, ridge, primal)
b = tree_add_scalar_mul(Sbottom, -ridge, dual_eq)
return a, b

solver = IterativeRefinement(
matvec_A=matvec_qp,
Expand Down
8 changes: 6 additions & 2 deletions tests/eq_qp_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

from absl.testing import absltest
from absl.testing import parameterized

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -108,7 +109,9 @@ def test_projection_hyperplane(self):
self.assertArraysAllClose(primal_sol, expected)
self.assertAllClose(qp.l2_optimality_error(sol, **hyperparams), 0.0)

def test_eq_constrained_qp_with_pytrees(self):
@parameterized.product(refine_regularization=[0., 1e-6])
def test_eq_constrained_qp_with_pytrees(self, refine_regularization):
# refine_regularization != 0. triggers a non regression test for issue #311
rng = onp.random.RandomState(0)
Q = rng.randn(7, 7)
Q = onp.dot(Q, Q.T)
Expand All @@ -131,7 +134,8 @@ def matvec_A(A, tup):

# With pytrees directly.
hyperparams = dict(params_obj=(Q, c), params_eq=(A, b))
qp = EqualityConstrainedQP(matvec_Q=matvec_Q, matvec_A=matvec_A)
qp = EqualityConstrainedQP(matvec_Q=matvec_Q, matvec_A=matvec_A,
refine_regularization=refine_regularization)
# sol.primal has the same pytree structure as the output of matvec_Q.
# sol.dual_eq has the same pytree structure as the output of matvec_A.
sol_pytree = qp.run(**hyperparams).params
Expand Down

0 comments on commit 3828913

Please sign in to comment.