Implement scipy.optimize.linear_sum_assignment #10403

Closed
carlosgmartin opened this issue Apr 21, 2022 · 8 comments
Labels
contributions welcome: The JAX team has not prioritized work on this. Community contributions are welcome.
enhancement: New feature or request.
P3 (no schedule): We have no plan to work on this and, if it is unassigned, we would be happy to review a PR.

Comments

@carlosgmartin
Contributor

Implement scipy.optimize.linear_sum_assignment, which solves the assignment problem. Among other things, this is useful for estimating the Wasserstein distance between two distributions based on their empirical measures.
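A minimal sketch of the Wasserstein use case mentioned above, assuming two point clouds of equal size with uniform weights. Under that assumption an optimal transport plan can be taken to be a permutation, so the assignment problem gives the empirical distance exactly. The helper name `empirical_wasserstein_1` is made up for illustration; only `scipy.optimize.linear_sum_assignment` is an existing API.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def empirical_wasserstein_1(x, y):
    """W1 distance between two equal-size point clouds with uniform weights.

    Hypothetical helper for illustration; x and y have shape (n, d).
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    # Pairwise Euclidean distances: cost[i, j] = ||x[i] - y[j]||.
    cost = np.linalg.norm(x[:, None, :] - y[None, :, :], axis=-1)
    # Optimal one-to-one matching between the two samples.
    rows, cols = linear_sum_assignment(cost)
    # Each matched pair carries mass 1/n, so the distance is the mean cost.
    return cost[rows, cols].mean()


x = np.array([[0.0], [1.0]])
y = np.array([[1.0], [2.0]])
w1 = empirical_wasserstein_1(x, y)
print(w1)
```

With unequal sample sizes or non-uniform weights the problem is a general transport problem rather than an assignment problem, and this sketch no longer applies.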

carlosgmartin added the enhancement label on Apr 21, 2022
hawkinsp added the contributions welcome and P3 (no schedule) labels on Apr 25, 2022
@avinashsai

@hawkinsp I can take this up. Could you please provide a bit more context on where to add the function, and any references if possible?

@riversdark

Anyone working on this? It would be nice to have this in JAX.

For reference, there is the paper "On implementing 2D rectangular assignment algorithms", mentioned in the SciPy documentation, which reviews many related algorithms, and a more recent paper, "A Fast Scalable Solver for the Dense Linear (Sum) Assignment Problem", which attempts to parallelise the algorithm.

@carlosgmartin
Contributor Author

@avinashsai @riversdark Here is scipy's C++ implementation. I ported it to JAX, though it's not in fully JITable form yet:

from itertools import count

from jax import numpy as jnp, random, jit
from jax.lax import cond, while_loop
from scipy.optimize import linear_sum_assignment

def augmenting_path(cost, u, v, path, row4col, i):
    minVal = 0
    num_remaining = cost.shape[1]
    remaining = jnp.arange(cost.shape[1])[::-1]

    SR = jnp.full(cost.shape[0], False)
    SC = jnp.full(cost.shape[1], False)
    shortestPathCosts = jnp.full(cost.shape[1], jnp.inf)

    sink = -1
    while sink == -1:
        index = -1
        lowest = jnp.inf
        SR = SR.at[i].set(True)

        for it in range(num_remaining):
            j = remaining[it]

            r = minVal + cost[i, j] - u[i] - v[j]

            path = cond(
                r < shortestPathCosts[j],
                lambda: path.at[j].set(i),
                lambda: path
            )
            shortestPathCosts = shortestPathCosts.at[j].min(r)

            index = cond(
                (shortestPathCosts[j] < lowest) | 
                ((shortestPathCosts[j] == lowest) & (row4col[j] == -1)),
                lambda: it,
                lambda: index
            )
            lowest = jnp.minimum(lowest, shortestPathCosts[j])

        minVal = lowest
        if minVal == jnp.inf: # infeasible cost matrix
            sink = -1
            break

        j = remaining[index]

        pred = row4col[j] == -1
        sink = cond(pred, lambda: j, lambda: sink)
        i = cond(~pred, lambda: row4col[j], lambda: i)

        SC = SC.at[j].set(True)
        num_remaining -= 1
        remaining = remaining.at[index].set(remaining[num_remaining])

    return sink, minVal, remaining, SR, SC, shortestPathCosts, path

def solve(cost):
    transpose = cost.shape[1] < cost.shape[0]

    if transpose:
        cost = cost.T

    u = jnp.full(cost.shape[0], 0.)
    v = jnp.full(cost.shape[1], 0.)
    path = jnp.full(cost.shape[1], -1)
    col4row = jnp.full(cost.shape[0], -1)
    row4col = jnp.full(cost.shape[1], -1)

    for curRow in range(cost.shape[0]):

        j, minVal, remaining, SR, SC, shortestPathCosts, path = augmenting_path(cost, u, v, path, row4col, curRow)

        u = u.at[curRow].add(minVal)

        mask = SR & (jnp.arange(cost.shape[0]) != curRow)
        u = u.at[mask].add(minVal - shortestPathCosts[col4row][mask])

        v = v.at[SC].add(shortestPathCosts[SC] - minVal)

        while True:
            i = path[j]
            row4col = row4col.at[j].set(i)

            col4row, j = col4row.at[i].set(j), col4row[i]

            if i == curRow:
                break

    if transpose:
        v = col4row.argsort()
        return col4row[v], v
    else:
        return jnp.arange(cost.shape[0]), col4row

def main():
    key = random.PRNGKey(0)
    for t in count():
        key, subkey = random.split(key)
        shape = random.randint(subkey, [2], 0, 6)

        key, subkey = random.split(key)
        cost = random.uniform(subkey, shape)

        if t < 0: # skip to failing case
            continue

        row_ind_1, col_ind_1 = linear_sum_assignment(cost)
        row_ind_2, col_ind_2 = solve(cost)

        print('{:5} {}'.format(t,
            (row_ind_1 == row_ind_2).all() and 
            (col_ind_1 == col_ind_2).all()
        ))

if __name__ == '__main__':
    main()
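Two things keep the port above from tracing under `jax.jit`: Python `while`/`if` statements that branch on array values, and boolean-mask indexing such as `u.at[mask].add(...)`, whose result shape depends on the data. The sketch below shows the usual replacements on toy problems, not on the solver itself; `first_index_at_least` and `add_where` are hypothetical names chosen for this example.

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def first_index_at_least(values, threshold):
    """Index of the first entry >= threshold, or -1 if there is none.

    A Python `while ...: break` loop rewritten as lax.while_loop: the
    "break" becomes part of the loop condition, and all loop state
    lives in the carry tuple.
    """
    n = values.shape[0]  # shapes are static, so this is a Python int

    def cond_fun(carry):
        i, found = carry
        return (i < n) & (found == -1)

    def body_fun(carry):
        i, found = carry
        found = jnp.where(values[i] >= threshold, i, found)
        return i + 1, found

    _, found = lax.while_loop(cond_fun, body_fun, (jnp.int32(0), jnp.int32(-1)))
    return found


@jax.jit
def add_where(u, mask, delta):
    """Same values as u.at[mask].add(delta[mask]), but with static shapes."""
    return jnp.where(mask, u + delta, u)


values = jnp.array([1.0, 3.0, 5.0])
print(first_index_at_least(values, 3.0))
print(add_where(values, jnp.array([True, False, True]), jnp.array([10.0, 20.0, 30.0])))
```

The carry passed to `lax.while_loop` must keep the same structure, shapes, and dtypes across iterations, which is why every variable the loop mutates has to be threaded through it explicitly.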

@rdilip

rdilip commented Nov 7, 2022

Any updates on this? This seems particularly important for set-to-set machine learning methods (e.g., DETR).

@carlosgmartin
Contributor Author

@avinashsai Are you still interested in implementing this?

@CoastEgo

CoastEgo commented Aug 15, 2024

In case you need it, I modified @carlosgmartin's code to make it JITable. I use it for a (5, 5) square cost matrix, so I commented out the transpose operation.

from itertools import count
from jax import numpy as jnp, random, jit
from jax import lax
import jax
from scipy.optimize import linear_sum_assignment
import time

@jax.jit
def augmenting_path(cost, u, v, path, row4col, i):
    minVal = 0
    remaining = jnp.arange(cost.shape[1])[::-1]
    num_remaining = cost.shape[1]
    SR = jnp.full(cost.shape[0], False)
    SC = jnp.full(cost.shape[1], False)
    shortestPathCosts = jnp.full(cost.shape[1], jnp.inf)

    sink = -1
    break_cond=False
    def cond_fun(carry):
        sink, minVal, remaining, SR, SC, shortestPathCosts, path, break_cond, i, u, v, row4col, cost, num_remaining=carry
        return (sink == -1)&(~break_cond)
    def while_loop_body(carry):
        sink, minVal, remaining, SR, SC, shortestPathCosts, path, break_cond, i, u, v, row4col, cost, num_remaining=carry
        index = -1
        lowest = jnp.inf
        SR = SR.at[i].set(True)
        def body_fun(carry):
            cost, u, v, path, row4col, i, remaining, minVal, shortestPathCosts, lowest, index, it= carry
            j = remaining[it]
            r = minVal + cost[i, j] - u[i] - v[j]
            path = lax.cond(
                r < shortestPathCosts[j],
                lambda: path.at[j].set(i),
                lambda: path
            )
            shortestPathCosts = shortestPathCosts.at[j].min(r)

            index = lax.cond(
                (shortestPathCosts[j] < lowest) | 
                ((shortestPathCosts[j] == lowest) & (row4col[j] == -1)),
                lambda: it,
                lambda: index
            )
            lowest = jnp.minimum(lowest, shortestPathCosts[j])
            it+=1
            return (cost, u, v, path, row4col, i, remaining, minVal, shortestPathCosts, lowest, index, it)
        carry=lax.while_loop(lambda x: x[-1]<num_remaining,body_fun,(cost, u, v, path, row4col, i, remaining, minVal, shortestPathCosts, lowest, index, 0))
        cost, u, v, path, row4col, i, remaining, minVal, shortestPathCosts, lowest, index, _=carry
        minVal = lowest
        def True_fun(carry):
            remaining,index,row4col,sink,i,SC,num_remaining,break_cond=carry
            sink = -1
            break_cond=True
            return (remaining,index,row4col,sink,i,SC,num_remaining,break_cond)
        def False_fun(carry):
            remaining,index,row4col,sink,i,SC,num_remaining,break_cond=carry
            j = remaining[index]

            pred = row4col[j] == -1
            sink = lax.cond(pred, lambda: j, lambda: sink)
            i = lax.cond(~pred, lambda: row4col[j], lambda: i)

            SC = SC.at[j].set(True)
            num_remaining -= 1
            remaining = remaining.at[index].set(remaining[num_remaining])
            return (remaining,index,row4col,sink,i,SC,num_remaining,break_cond)
        carry=lax.cond(minVal == jnp.inf,True_fun,False_fun,(remaining,index,row4col,sink,i,SC,num_remaining,break_cond))
        remaining,index,row4col,sink,i,SC,num_remaining,break_cond=carry
        return (sink, minVal, remaining, SR, SC, shortestPathCosts, path, break_cond, i, u, v, row4col, cost, num_remaining)
    carry=lax.while_loop(cond_fun,while_loop_body,(sink, minVal, remaining, SR, SC, shortestPathCosts, path, break_cond, i, u, v, row4col, cost, num_remaining))
    sink, minVal, remaining, SR, SC, shortestPathCosts, path, break_cond, i, u, v, row4col, cost, num_remaining=carry
    return sink, minVal, remaining, SR, SC, shortestPathCosts, path
@jax.jit
def solve(cost):
    '''
    Solves the linear sum assignment problem using the Hungarian algorithm.
    adapted from https://github.com/google/jax/issues/10403
    Parameters:
    - cost (ndarray): The cost matrix representing the assignment problem.

    Returns:
    - row_ind (ndarray): The row indices of the assigned elements.
    - col_ind (ndarray): The column indices of the assigned elements.
    '''

    # transpose = cost.shape[1] < cost.shape[0]
    # if transpose:  # check whether the matrix needs transposing; unnecessary for a square matrix
    #     cost = cost.T

    u = jnp.full(cost.shape[0], 0.)
    v = jnp.full(cost.shape[1], 0.)
    path = jnp.full(cost.shape[1], -1)
    col4row = jnp.full(cost.shape[0], -1)
    row4col = jnp.full(cost.shape[1], -1)
    def loop_body(carry,curRow):
        u,v,path,col4row,row4col,cost=carry
        j, minVal, remaining, SR, SC, shortestPathCosts, path = augmenting_path(cost, u, v, path, row4col, curRow)

        u = u.at[curRow].add(minVal)

        mask = SR & (jnp.arange(cost.shape[0]) != curRow)
        u=jnp.where(mask,u+minVal-shortestPathCosts[col4row],u)
        #u = u.at[mask].add(minVal - shortestPathCosts[col4row][mask])
        v=jnp.where(SC,v+shortestPathCosts-minVal,v)
        #v = v.at[SC].add(shortestPathCosts[SC] - minVal)
        def while_loop_body(carry):
            path, j, row4col, col4row, break_cond=carry

            i = path[j]

            row4col = row4col.at[j].set(i)

            col4row, j = col4row.at[i].set(j), col4row[i]

            break_cond=~(i==curRow)
            return (path,j, row4col, col4row, break_cond)
        carry=lax.while_loop(lambda x: x[-1],while_loop_body,(path,j, row4col, col4row, True))
        path,j, row4col, col4row, break_cond=carry
        return (u,v,path,col4row,row4col,cost),curRow
    carry,_=lax.scan(loop_body,(u,v,path,col4row,row4col,cost),jnp.arange(cost.shape[0]))
    u,v,path,col4row,row4col,cost=carry
    return jnp.arange(cost.shape[0]), col4row
    # if transpose:
    #     v = col4row.argsort()
    #     return col4row[v], v
    # else:
    #     return jnp.arange(cost.shape[0]), col4row

def main():
    key = random.PRNGKey(0)
    for t in count():
        key, subkey = random.split(key)
        shape = random.randint(subkey, [2], 0, 6)

        key, subkey = random.split(key)
        cost = random.uniform(subkey, (5,5))
        if t < 0: # skip to failing case
            continue
        st=time.perf_counter()
        row_ind_1, col_ind_1 = linear_sum_assignment(cost)
        end=time.perf_counter()
        print('scipy time =',end-st)
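        st=time.perf_counter()
        # JAX dispatches work asynchronously, so wait for the result before
        # stopping the clock. The first call also includes compilation time.
        row_ind_2, col_ind_2 = jax.block_until_ready(solve(cost))
        end=time.perf_counter()
        print('jax time =',end-st)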
        print('{:5} {}'.format(t,
            (row_ind_1 == row_ind_2).all() and 
            (col_ind_1 == col_ind_2).all()
        ))
        if not ((row_ind_1 == row_ind_2).all() and (col_ind_1 == col_ind_2).all()):
            break

if __name__ == '__main__':
    main()
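On the commented-out transpose: array shapes are static under `jax.jit`, so `cost.shape[1] < cost.shape[0]` is an ordinary Python bool at trace time and the `if transpose:` branches can stay as plain Python inside a jitted function. The sketch below illustrates this under one loud assumption: `brute_force_wide` is a hypothetical stand-in for a solver that requires rows <= columns. It enumerates every assignment, so its cost grows factorially and it is only usable on tiny matrices.

```python
import itertools

import jax
import jax.numpy as jnp


def brute_force_wide(cost):
    """Stand-in solver (illustration only): returns col4row for rows <= cols."""
    n_rows, n_cols = cost.shape
    # All ways to give each row a distinct column; built at trace time.
    perms = jnp.array(list(itertools.permutations(range(n_cols), n_rows)))
    totals = cost[jnp.arange(n_rows), perms].sum(axis=1)
    return perms[jnp.argmin(totals)]


@jax.jit
def solve_any_shape(cost):
    # Resolved while tracing: each input shape gets its own compiled version.
    transpose = cost.shape[1] < cost.shape[0]
    if transpose:
        cost = cost.T
    col4row = brute_force_wide(cost)
    if transpose:
        # col4row now maps original columns to original rows; sort by row
        # so the output follows the (row_ind, col_ind) convention.
        order = jnp.argsort(col4row)
        return col4row[order], order
    return jnp.arange(cost.shape[0]), col4row


tall = jnp.array([[4.0, 1.0], [2.0, 0.0], [5.0, 6.0]])
rows, cols = solve_any_shape(tall)
print(rows, cols, tall[rows, cols].sum())
```

The same wrapper shape should carry over to the jitted `solve` above if its square-only body replaces the stand-in, though that combination is untested here.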

@carlosgmartin
Contributor Author

@CoastEgo I believe @odneill was working on something along these lines, too.

@carlosgmartin
Contributor Author

carlosgmartin commented Aug 18, 2024

Given this comment, let's move further discussion to google-deepmind/optax#954.

carlosgmartin closed this as not planned on Aug 22, 2024

6 participants