Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add improved version of Hungarian algorithm. #1140

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion optax/_src/alias.py
Original file line number Diff line number Diff line change
Expand Up @@ -2482,7 +2482,7 @@ def lbfgs(
... )
... params = optax.apply_updates(params, updates)
... print('Objective function: ', f(params))
Objective function: 7.5166864
Objective function: 7.516686...
Objective function: 7.460699e-14
Objective function: 2.6505726e-28
Objective function: 0.0
Expand Down
181 changes: 179 additions & 2 deletions optax/assignment/_hungarian_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@
import functools

import jax
import jax.numpy as jnp
from jax import lax, numpy as jnp


def hungarian_algorithm(cost_matrix):
def base_hungarian_algorithm(cost_matrix):
r"""The Hungarian algorithm for the linear assignment problem.

In `this problem <https://en.wikipedia.org/wiki/Linear_assignment_problem>`_,
Expand Down Expand Up @@ -353,3 +353,180 @@ def augment(carry):
)

return costs, u, v, path, row4col, col4row


def _masked_argmin(array, mask):
  """Returns the index of the smallest entry of ``array`` where ``mask`` holds.

  Entries whose mask is ``False`` are replaced by ``+inf`` before the argmin,
  so they are only selected when every entry is masked out (in which case the
  result is index 0, the first of the all-``inf`` entries).

  Args:
    array: A 1-D array of values.
    mask: A boolean array, broadcastable against ``array``; ``True`` marks the
      entries that may be selected.

  Returns:
    A scalar integer array holding the selected index.
  """
  array = jnp.where(mask, array, jnp.inf)
  # Always true at runtime; kept only to narrow the type for static checkers.
  assert isinstance(array, jax.Array)
  return jnp.argmin(array)


def hungarian_algorithm(cost_matrix):
  r"""The Hungarian algorithm for the linear assignment problem.

  In `this problem <https://en.wikipedia.org/wiki/Linear_assignment_problem>`_,
  we are given an :math:`n \times m` cost matrix. The goal is to compute an
  assignment, i.e. a set of pairs of rows and columns, in such a way that:

  - At most one column is assigned to each row.
  - At most one row is assigned to each column.
  - The total number of assignments is :math:`\min(n, m)`.
  - The assignment minimizes the sum of costs.

  Equivalently, given a weighted complete bipartite graph, the problem is to
  find a maximum-cardinality matching that minimizes the sum of the weights of
  the edges included in the matching.

  Formally, the problem is as follows. Given
  :math:`C \in \mathbb{R}^{n \times m}`, solve the following
  `integer linear program
  <https://en.wikipedia.org/wiki/Integer_linear_program>`_:

  .. math::

    \begin{align*}
      \text{minimize} \quad & \sum_{i \in [n]} \sum_{j \in [m]} C_{ij} X_{ij}
      \\ \text{subject to} \quad
      & X_{ij} \in \{0, 1\} & \forall i \in [n], j \in [m] \\
      & \sum_{i \in [n]} X_{ij} \leq 1 & \forall j \in [m] \\
      & \sum_{j \in [m]} X_{ij} \leq 1 & \forall i \in [n] \\
      & \sum_{i \in [n]} \sum_{j \in [m]} X_{ij} = \min(n, m)
    \end{align*}

  The `Hungarian algorithm <https://en.wikipedia.org/wiki/Hungarian_algorithm>`_
  is a cubic-time algorithm that solves this problem.

  This implementation is based on that of the Scenic library (see references).

  Unlike `base_hungarian_algorithm`, this version yields a simpler Jaxpr and
  appears to be faster.

  Args:
    cost_matrix: A matrix of costs.

  Returns:
    A pair ``(i, j)`` where ``i`` is an array of row indices and ``j`` is an
    array of column indices, each of length ``min(n, m)``.
    The cost of the assignment is ``cost_matrix[i, j].sum()``.

  Examples:
    >>> import optax
    >>> from jax import numpy as jnp
    >>> cost = jnp.array(
    ...  [
    ...    [8, 4, 7],
    ...    [5, 2, 3],
    ...    [9, 6, 7],
    ...    [9, 4, 8],
    ...  ])
    >>> i, j = optax.assignment.hungarian_algorithm(cost)
    >>> print("cost:", cost[i, j].sum())
    cost: 15
    >>> cost = jnp.array(
    ...  [
    ...    [90, 80, 75, 70],
    ...    [35, 85, 55, 65],
    ...    [125, 95, 90, 95],
    ...    [45, 110, 95, 115],
    ...    [50, 100, 90, 100],
    ...  ])
    >>> i, j = optax.assignment.hungarian_algorithm(cost)
    >>> print("cost:", cost[i, j].sum())
    cost: 265

  References:
    Dehghani et al., `Scenic: A JAX Library for Computer Vision Research and
    Beyond <https://arxiv.org/abs/2110.11403>`_, 2022
  """

  # Indexing convention used below: rows and columns are 1-based inside the
  # loops. Column 0 is a virtual column that temporarily holds the row being
  # inserted, and ``parent[col] == 0`` means "column ``col`` is unassigned".

  def row_fn(state, row):
    """Inserts one row into the matching (one step of the outer scan)."""

    def dfs_body_fn(state):
      u, v, used, minv, path, col = state

      # mark the current column as visited
      used = used.at[col].set(True)
      unused_slice = ~used[1:]

      # row currently matched to the visited column (1-based)
      cur_row = parent[col]

      # reduced costs of that row; relax minv and remember, for every column
      # that improved, the column it was reached from
      cur = cost_matrix[cur_row - 1, :] - u[cur_row] - v[1:]
      cur = jnp.where(unused_slice, cur, jnp.inf)
      path = jnp.where(cur < minv, col, path)
      minv = jnp.where(cur < minv, cur, minv)  # type: ignore

      # next column: the unvisited column with the smallest slack
      col = _masked_argmin(minv, unused_slice) + 1
      delta = minv.min(where=unused_slice, initial=jnp.inf)

      # update potentials; unvisited columns scatter into the spare slot
      # u[rows + 1], which is never read back
      indices = jnp.where(used, parent, rows + 1)
      u = u.at[indices].add(delta)
      v = jnp.where(used, v - delta, v)
      minv = jnp.where(unused_slice, minv - delta, minv)

      return u, v, used, minv, path, col

    def dfs_cond_fn(state):
      # keep searching until an unassigned column is reached
      _, _, _, _, _, col = state
      return parent[col] != 0

    def back_body_fn(state):
      # hand the current column to the row of its predecessor on the path
      parent, old_col = state
      new_col = path[old_col - 1]
      parent = parent.at[old_col].set(parent[new_col])
      return parent, new_col

    def back_cond_fn(state):
      # the virtual column 0 is the start of every path
      _, col = state
      return col != 0

    u, v, parent = state
    # park the new row (1-based) on the virtual column 0
    parent = parent.at[0].set(row + 1)

    path = jnp.zeros(cols, int)  # predecessor column of each column
    used = jnp.zeros(cols + 1, bool)  # visited flags, virtual column included
    minv = jnp.full(cols, jnp.inf)  # support array: smallest slack per column
    col = 0

    # run the inner while loop (i.e. DFS) until a free column is found
    state = u, v, used, minv, path, col
    state = lax.while_loop(dfs_cond_fn, dfs_body_fn, state)
    u, v, _, _, path, col = state

    # backtrack the DFS path, updating parents along it
    parent, _ = lax.while_loop(back_cond_fn, back_body_fn, (parent, col))

    return (u, v, parent), None

  if cost_matrix.shape[0] == 0 or cost_matrix.shape[1] == 0:
    return jnp.zeros(0, int), jnp.zeros(0, int)

  # the loops below require rows <= cols, so tall matrices are transposed
  transpose = cost_matrix.shape[0] > cost_matrix.shape[1]

  if transpose:
    cost_matrix = cost_matrix.T

  rows, cols = cost_matrix.shape

  u = jnp.zeros(rows + 2)  # row potential (index 0 unused, last is spare)
  v = jnp.zeros(cols + 1)  # column potential
  parent = jnp.zeros(cols + 1, int)  # maps columns to rows

  # loop over the rows of the cost matrix
  (u, v, parent), _ = lax.scan(row_fn, (u, v, parent), jnp.arange(rows))
  # -v[0] is the matching cost

  # top_k is costly, so skip it when possible (i.e. for square matrices)
  if rows == cols:
    parent, indices = parent[1:], jnp.arange(rows)
  else:
    # unassigned columns hold 0, so the `rows` largest entries are exactly
    # the assigned columns
    parent, indices = lax.top_k(parent[1:], rows)

  parent -= 1  # switch back to 0-based indexing

  if transpose:
    return indices, parent

  return parent, indices
20 changes: 13 additions & 7 deletions optax/assignment/_hungarian_algorithm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,23 @@
import jax
import jax.numpy as jnp
import jax.random as jrd
from optax.assignment import _hungarian_algorithm
import scipy

from ._hungarian_algorithm import hungarian_algorithm, base_hungarian_algorithm


class HungarianAlgorithmTest(parameterized.TestCase):

@parameterized.product(
fn=[hungarian_algorithm, base_hungarian_algorithm],
n=[0, 1, 2, 4, 8, 16],
m=[0, 1, 2, 4, 8, 16],
)
def test_hungarian_algorithm(self, n, m):
def test_hungarian_algorithm(self, fn, n, m):
key = jrd.key(0)
costs = jrd.normal(key, (n, m))

i, j = _hungarian_algorithm.hungarian_algorithm(costs)
i, j = fn(costs)

r = min(costs.shape)

Expand Down Expand Up @@ -86,16 +88,17 @@ def test_hungarian_algorithm(self, n, m):
assert jnp.isclose(cost_optax, cost_scipy)

@parameterized.product(
fn=[hungarian_algorithm, base_hungarian_algorithm],
k=[0, 1, 2, 4],
n=[0, 1, 2, 4],
m=[0, 1, 2, 4],
)
def test_hungarian_algorithm_vmap(self, k, n, m):
def test_hungarian_algorithm_vmap(self, fn, k, n, m):
key = jrd.key(0)
costs = jrd.normal(key, (k, n, m))

with self.subTest('works under vmap'):
i, j = jax.vmap(_hungarian_algorithm.hungarian_algorithm)(costs)
i, j = jax.vmap(fn)(costs)

r = min(costs.shape[1:])

Expand All @@ -105,12 +108,15 @@ def test_hungarian_algorithm_vmap(self, k, n, m):
with self.subTest('batch j has correct shape'):
assert j.shape == (k, r)

def test_hungarian_algorithm_jit(self):
@parameterized.product(
fn=[hungarian_algorithm, base_hungarian_algorithm],
)
def test_hungarian_algorithm_jit(self, fn):
key = jrd.key(0)
costs = jrd.normal(key, (20, 30))

with self.subTest('works under jit'):
i, j = jax.jit(_hungarian_algorithm.hungarian_algorithm)(costs)
i, j = jax.jit(fn)(costs)

r = min(costs.shape)

Expand Down
Loading