Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add linear program solver based on the restarted Halpern primal-dual hybrid gradient (rHPDHG) algorithm. #1154

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ disable=R,
wrong-import-order,
xrange-builtin,
zip-builtin-not-iterating,
invalid-name,


[REPORTS]
Expand Down
12 changes: 12 additions & 0 deletions docs/api/linprog.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
Linear programming
==================

.. currentmodule:: optax.linprog

.. autosummary::
rhpdhg


Restarted Halpern primal-dual hybrid gradient method
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: rhpdhg
19 changes: 18 additions & 1 deletion docs/gallery.rst
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@
.. only:: html

.. image:: /images/examples/linear_assignment_problem.png
:alt:
:alt: Linear assignment problem.

:doc:`_collections/examples/linear_assignment_problem`

Expand All @@ -219,6 +219,23 @@
</div>


.. raw:: html

<div class="sphx-glr-thumbcontainer" tooltip="Linear programming.">

.. only:: html

.. image:: /images/examples/linear_programming.png
:alt: Linear programming.

:doc:`_collections/examples/linear_programming`

.. raw:: html

<div class="sphx-glr-thumbnail-title">Linear programming.</div>
</div>


.. raw:: html

</div>
Expand Down
Binary file added docs/images/examples/linear_programming.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1 change: 1 addition & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ for instructions on installing JAX.
:caption: 📖 Reference
:maxdepth: 2

api/linprog
api/assignment
api/optimizers
api/transformations
Expand Down
229 changes: 229 additions & 0 deletions examples/linear_programming.ipynb

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions optax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from optax import assignment
from optax import contrib
from optax import linprog
from optax import losses
from optax import monte_carlo
from optax import perturbations
Expand Down Expand Up @@ -364,6 +365,7 @@
"lion",
"linear_onecycle_schedule",
"linear_schedule",
"linprog",
"log_cosh",
"lookahead",
"LookaheadParams",
Expand Down
2 changes: 1 addition & 1 deletion optax/_src/alias.py
Original file line number Diff line number Diff line change
Expand Up @@ -2482,7 +2482,7 @@ def lbfgs(
... )
... params = optax.apply_updates(params, updates)
... print('Objective function: ', f(params))
Objective function: 7.5166864
Objective function: 7.516686...
Objective function: 7.460699e-14
Objective function: 2.6505726e-28
Objective function: 0.0
Expand Down
19 changes: 19 additions & 0 deletions optax/linprog/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Copyright 2024 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The linear programming sub-package."""

# pylint:disable=g-importing-member

from optax.linprog._rhpdhg import solve_general as rhpdhg
215 changes: 215 additions & 0 deletions optax/linprog/_rhpdhg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
# Copyright 2024 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The restarted Halpern primal-dual hybrid gradient method."""

from jax import lax, numpy as jnp
from optax import tree_utils as otu


def solve_canonical(
c, A, b, iters, reflect=True, restarts=True, tau=None, sigma=None
):
r"""Solves a linear program using the restarted Halpern primal-dual hybrid
gradient (RHPDHG) method.

Minimizes :math:`c \cdot x` subject to :math:`A x = b` and :math:`x \geq 0`.

See also `MPAX <https://github.com/MIT-Lu-Lab/MPAX>`_.

Args:
c: Cost vector.
A: Equality constraint matrix.
b: Equality constraint vector.
iters: Number of iterations to run the solver for.
reflect: Use reflection. See paper for details.
restarts: Use restarts. See paper for details.
tau: Primal step size. See paper for details.
sigma: Dual step size. See paper for details.

Returns:
A dictionary whose entries are as follows:
- primal: The final primal solution.
- dual: The final dual solution.
- primal_iterates: The primal iterates.
- dual_iterates: The dual iterates.

Examples:
>>> from jax import numpy as jnp
>>> import optax
>>> c = -jnp.array([2, 1])
>>> A = jnp.zeros([0, 2])
>>> b = jnp.zeros(0)
>>> G = jnp.array([[3, 1], [1, 1], [1, 4]])
>>> h = jnp.array([21, 9, 24])
>>> x = optax.linprog.rhpdhg(c, A, b, G, h, 1_000_000)['primal']
>>> print(x[0])
5.99...
>>> print(x[1])
2.99...

References:
Haihao Lu, Jinwen Yang, `Restarted Halpern PDHG for Linear Programming
<https://arxiv.org/abs/2407.16144>`_, 2024
Haihao Lu, Zedong Peng, Jinwen Yang, `MPAX: Mathematical Programming in JAX
<https://arxiv.org/abs/2412.09734>`_, 2024
"""

if tau is None or sigma is None:
A_norm = jnp.linalg.norm(A, axis=(0, 1), ord=2)
if tau is None:
tau = 1 / (2 * A_norm)
if sigma is None:
sigma = 1 / (2 * A_norm)

def T(z):
# primal dual hybrid gradient (PDHG)
x, y = z
xn = x + tau * (y @ A - c)
xn = xn.clip(min=0)
yn = y + sigma * (b - A @ (2 * xn - x))
return xn, yn

def H(z, k, z0):
# Halpern PDHG
Tz = T(z)
if reflect:
zc = otu.tree_sub(otu.tree_scalar_mul(2, Tz), z)
else:
zc = Tz
kp2 = k + 2
zn = otu.tree_add(
otu.tree_scalar_mul((k + 1) / kp2, zc),
otu.tree_scalar_mul(1 / kp2, z0),
)
return zn, Tz

def update(carry, _):
z, k, z0, d0 = carry
zn, Tz = H(z, k, z0)

if restarts:
d = otu.tree_l2_norm(otu.tree_sub(z, Tz), squared=True)
restart = d <= d0 * jnp.exp(-2)
new_carry = otu.tree_where(
restart,
(zn, 0, zn, d),
(zn, k + 1, z0, d0),
)
else:
new_carry = zn, k + 1, z0, d0

return new_carry, z

def run():
m, n = A.shape
x = jnp.zeros(n)
y = jnp.zeros(m)
z0 = x, y
d0 = otu.tree_l2_norm(otu.tree_sub(z0, T(z0)), squared=True)
(z, _, _, _), zs = lax.scan(update, (z0, 0, z0, d0), length=iters)
x, y = z
xs, ys = zs
return {
"primal": x,
"dual": y,
"primal_iterates": xs,
"dual_iterates": ys,
}

return run()


def general_to_canonical(c, A, b, G, h):
"""Converts a linear program from general form to canonical form.

The solution to the new linear program will consist of the concatenation of
- the positive part of x
- the negative part of x
- slacks

That is, we go from

Minimize c · x subject to
A x = b
G x ≤ h

to

Minimize c · (x⁺ - x⁻) subject to
A (x⁺ - x⁻) = b
G (x⁺ - x⁻) + s = h
x⁺, x⁻, s ≥ 0

Args:
c: Cost vector.
A: Equality constraint matrix.
b: Equality constraint vector.
G: Inequality constraint matrix.
h: Inequality constraint vector.

Returns:
A triple (c', A', b') representing the corresponding canonical form.
"""
c_can = jnp.concatenate([c, -c, jnp.zeros(h.size)])
G_ = jnp.concatenate([G, -G, jnp.eye(h.size)], 1)
A_ = jnp.concatenate([A, -A, jnp.zeros([b.size, h.size])], 1)
A_can = jnp.concatenate([A_, G_], 0)
b_can = jnp.concatenate([b, h])
return c_can, A_can, b_can


def solve_general(
c, A, b, G, h, iters, reflect=True, restarts=True, tau=None, sigma=None
):
r"""Solves a linear program using the restarted Halpern primal-dual hybrid
gradient (RHPDHG) method.

Minimizes :math:`c \cdot x` subject to :math:`A x = b` and :math:`G x \leq h`.

See also `MPAX <https://github.com/MIT-Lu-Lab/MPAX>`_.

Args:
c: Cost vector.
A: Equality constraint matrix.
b: Equality constraint vector.
G: Inequality constraint matrix.
h: Inequality constraint vector.
iters: Number of iterations to run the solver for.
reflect: Use reflection. See paper for details.
restarts: Use restarts. See paper for details.
tau: Primal step size. See paper for details.
sigma: Dual step size. See paper for details.

Returns:
A dictionary whose entries are as follows:
- primal: The final primal solution.
- slacks: The final primal slack values.
- canonical_result: The result for the canonical program that was used
internally to find this solution. See paper for details.

References:
Haihao Lu, Jinwen Yang, `Restarted Halpern PDHG for Linear Programming
<https://arxiv.org/abs/2407.16144>`_, 2024
Haihao Lu, Zedong Peng, Jinwen Yang, `MPAX: Mathematical Programming in JAX
<https://arxiv.org/abs/2412.09734>`_, 2024
"""
canonical = general_to_canonical(c, A, b, G, h)
result = solve_canonical(*canonical, iters, reflect, restarts, tau, sigma)
x_pos, x_neg, slacks = jnp.split(result["primal"], [c.size, c.size * 2])
return {
"primal": x_pos - x_neg,
"slacks": slacks,
"canonical_result": result,
}
Loading
Loading