diff --git a/jaxopt/_src/eq_qp.py b/jaxopt/_src/eq_qp.py index 65fc3936..40fae6b5 100644 --- a/jaxopt/_src/eq_qp.py +++ b/jaxopt/_src/eq_qp.py @@ -25,7 +25,7 @@ from jaxopt._src import base from jaxopt._src import implicit_diff as idf -from jaxopt._src.tree_util import tree_add, tree_sub +from jaxopt._src.tree_util import tree_add, tree_sub, tree_add_scalar_mul from jaxopt._src.tree_util import tree_vdot, tree_negative, tree_l2_norm from jaxopt._src.linear_operator import _make_linear_operator from jaxopt._src.cvxpy_wrapper import _check_params @@ -143,7 +143,9 @@ def matvec_qp(_, x): def matvec_regularized_qp(_, x): primal, dual_eq = x Stop, Sbottom = matvec(x) - return Stop + ridge * primal, Sbottom - ridge * dual_eq + a = tree_add_scalar_mul(Stop, ridge, primal) + b = tree_add_scalar_mul(Sbottom, -ridge, dual_eq) + return a, b solver = IterativeRefinement( matvec_A=matvec_qp, diff --git a/tests/eq_qp_test.py b/tests/eq_qp_test.py index 11fb0200..1edc9124 100644 --- a/tests/eq_qp_test.py +++ b/tests/eq_qp_test.py @@ -13,6 +13,7 @@ # limitations under the License. from absl.testing import absltest +from absl.testing import parameterized import jax import jax.numpy as jnp @@ -108,7 +109,9 @@ def test_projection_hyperplane(self): self.assertArraysAllClose(primal_sol, expected) self.assertAllClose(qp.l2_optimality_error(sol, **hyperparams), 0.0) - def test_eq_constrained_qp_with_pytrees(self): + @parameterized.product(refine_regularization=[0., 1e-6]) + def test_eq_constrained_qp_with_pytrees(self, refine_regularization): + # refine_regularization != 0. triggers a non regression test for issue #311 rng = onp.random.RandomState(0) Q = rng.randn(7, 7) Q = onp.dot(Q, Q.T) @@ -131,7 +134,8 @@ def matvec_A(A, tup): # With pytrees directly. hyperparams = dict(params_obj=(Q, c), params_eq=(A, b)) - qp = EqualityConstrainedQP(matvec_Q=matvec_Q, matvec_A=matvec_A) + qp = EqualityConstrainedQP(matvec_Q=matvec_Q, matvec_A=matvec_A, + refine_regularization=refine_regularization) # sol.primal has the same pytree structure as the output of matvec_Q. # sol.dual_eq has the same pytree structure as the output of matvec_A. sol_pytree = qp.run(**hyperparams).params