-
Notifications
You must be signed in to change notification settings - Fork 2.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement scipy.optimize.linear_sum_assignment
#10403
Comments
@hawkinsp I can take this up. Can you please provide a bit more context on where to add the function, and any references if possible?
Anyone working on this? It would be nice to have this in JAX. For reference there is this paper On implementing 2D rectangular assignment algorithms mentioned in the SciPy documentation, that reviews many related algorithms, and a more recent paper, A Fast Scalable Solver for the Dense Linear (Sum) Assignment Problem that attempts to parallelise the algorithm. |
@avinashsai @riversdark Here is scipy's C++ implementation. I ported it to JAX, though it's not in fully JITable form yet: from itertools import count
from jax import numpy as jnp, random, jit
from jax.lax import cond, while_loop
from scipy.optimize import linear_sum_assignment
def augmenting_path(cost, u, v, path, row4col, i):
    """Dijkstra-style search for a shortest augmenting path starting at row ``i``.

    Port of the shortest-augmenting-path step of SciPy's
    ``linear_sum_assignment`` C++ solver.  Uses Python ``while``/``for``
    over traced values, so it only runs eagerly — per the author's own
    note in this thread, it is "not in fully JITable form yet".

    Parameters:
    - cost: (n_rows, n_cols) cost matrix; assumes n_rows <= n_cols
      (``solve`` transposes before calling — TODO confirm for other callers).
    - u, v: current dual variables for rows / columns.
    - path: per-column predecessor row indices; updated and returned.
    - row4col: current column -> row assignment, -1 meaning unassigned.
    - i: index of the unassigned row to start the search from.

    Returns ``(sink, minVal, remaining, SR, SC, shortestPathCosts, path)``:
    ``sink`` is the free column that terminates the path (-1 if the cost
    matrix is infeasible), ``minVal`` the cost of that path, ``SR``/``SC``
    boolean masks of scanned rows/columns, ``shortestPathCosts`` the
    tentative per-column distances.
    """
    minVal = 0
    num_remaining = cost.shape[1]
    # Columns not yet scanned; reversed so removal below matches SciPy's order.
    remaining = jnp.arange(cost.shape[1])[::-1]
    SR = jnp.full(cost.shape[0], False)
    SC = jnp.full(cost.shape[1], False)
    shortestPathCosts = jnp.full(cost.shape[1], jnp.inf)
    sink = -1
    while sink == -1:
        index = -1
        lowest = jnp.inf
        SR = SR.at[i].set(True)
        # Relax every unscanned column from the current row i.
        for it in range(num_remaining):
            j = remaining[it]
            # Reduced cost of extending the path through (i, j).
            r = minVal + cost[i, j] - u[i] - v[j]
            path = cond(
                r < shortestPathCosts[j],
                lambda: path.at[j].set(i),
                lambda: path
            )
            shortestPathCosts = shortestPathCosts.at[j].min(r)
            # Track the closest column; ties prefer an unassigned column so
            # the search can terminate sooner.
            index = cond(
                (shortestPathCosts[j] < lowest) |
                ((shortestPathCosts[j] == lowest) & (row4col[j] == -1)),
                lambda: it,
                lambda: index
            )
            lowest = jnp.minimum(lowest, shortestPathCosts[j])
        minVal = lowest
        if minVal == jnp.inf: # infeasible cost matrix
            sink = -1
            break
        j = remaining[index]
        pred = row4col[j] == -1
        # Unassigned column: the augmenting path ends here.  Otherwise
        # continue the search from the row currently holding column j.
        sink = cond(pred, lambda: j, lambda: sink)
        i = cond(~pred, lambda: row4col[j], lambda: i)
        SC = SC.at[j].set(True)
        # Remove column j from `remaining` by swapping in the last entry.
        num_remaining -= 1
        remaining = remaining.at[index].set(remaining[num_remaining])
    return sink, minVal, remaining, SR, SC, shortestPathCosts, path
def solve(cost):
    """Solve the linear sum assignment problem for `cost` (eager version).

    Returns ``(row_ind, col_ind)`` matching SciPy's
    ``linear_sum_assignment`` convention: row i is assigned to column
    ``col_ind[i]``.  Rectangular matrices are handled by transposing so
    the row dimension is the smaller one.
    """
    transpose = cost.shape[1] < cost.shape[0]
    if transpose:
        cost = cost.T
    # Dual variables for rows (u) and columns (v).
    u = jnp.full(cost.shape[0], 0.)
    v = jnp.full(cost.shape[1], 0.)
    path = jnp.full(cost.shape[1], -1)
    # col4row[i]: column assigned to row i; row4col[j]: row assigned to
    # column j; -1 means unassigned.
    col4row = jnp.full(cost.shape[0], -1)
    row4col = jnp.full(cost.shape[1], -1)
    for curRow in range(cost.shape[0]):
        # Find a shortest augmenting path from the next unassigned row.
        # NOTE(review): an infeasible matrix makes augmenting_path return
        # sink == -1, which this loop does not check — TODO confirm callers
        # only pass finite costs.
        j, minVal, remaining, SR, SC, shortestPathCosts, path = augmenting_path(cost, u, v, path, row4col, curRow)
        # Update duals along the scanned rows/columns.
        u = u.at[curRow].add(minVal)
        mask = SR & (jnp.arange(cost.shape[0]) != curRow)
        u = u.at[mask].add(minVal - shortestPathCosts[col4row][mask])
        v = v.at[SC].add(shortestPathCosts[SC] - minVal)
        # Augment: walk predecessors from the sink column back to curRow,
        # flipping assignments along the path.
        while True:
            i = path[j]
            row4col = row4col.at[j].set(i)
            # Simultaneously assign column j to row i and pick up the
            # column previously held by row i.
            col4row, j = col4row.at[i].set(j), col4row[i]
            if i == curRow:
                break
    if transpose:
        # Undo the transpose: sort assignments back into row order.
        v = col4row.argsort()
        return col4row[v], v
    else:
        return jnp.arange(cost.shape[0]), col4row
def main():
    """Randomized differential test of the JAX ``solve`` against SciPy.

    Draws random rectangular cost matrices and prints, per trial, whether
    ``solve`` agrees with ``scipy.optimize.linear_sum_assignment``.
    Runs forever; interrupt to stop.
    """
    key = random.PRNGKey(0)
    for t in count():
        # One split for the random shape (each side in [0, 6)), one for
        # the uniform cost matrix itself.
        key, subkey = random.split(key)
        shape = random.randint(subkey, [2], 0, 6)
        key, subkey = random.split(key)
        cost = random.uniform(subkey, shape)
        row_ind_1, col_ind_1 = linear_sum_assignment(cost)
        row_ind_2, col_ind_2 = solve(cost)
        # bool(...) ensures a plain True/False is printed rather than a
        # 0-d array.  (Also fixes the stray '|' after main() and drops the
        # dead `if t < 0: continue` placeholder from the original paste.)
        agree = bool((row_ind_1 == row_ind_2).all() and
                     (col_ind_1 == col_ind_2).all())
        print('{:5} {}'.format(t, agree))

if __name__ == '__main__':
    main()
Any updates on this? This seems particularly important to have for set to set machine learning methods (eg detr). |
@avinashsai Are you still interested in implementing this? |
In case you need it, I modified the code by @carlosgmartin to make this JITable. I use this for a (5,5) square cost matrix so I comment out the transpose operation. from itertools import count
from jax import numpy as jnp, random, jit
from jax import lax
import jax
from scipy.optimize import linear_sum_assignment
import time
@jax.jit
def augmenting_path(cost, u, v, path, row4col, i):
    """JIT-compatible shortest-augmenting-path search starting at row ``i``.

    Same algorithm as the eager version earlier in this thread, but with
    the Python ``while``/``for``/``break`` rewritten as ``lax.while_loop``
    and ``lax.cond`` so the whole function can be traced under ``jax.jit``.
    The early ``break`` on an infeasible matrix is emulated with the
    ``break_cond`` flag threaded through the loop carry.

    Returns ``(sink, minVal, remaining, SR, SC, shortestPathCosts, path)``
    with the same meanings as the eager version; ``sink == -1`` signals
    an infeasible cost matrix.
    """
    minVal = 0
    # Columns not yet scanned, reversed to match SciPy's scan order.
    remaining = jnp.arange(cost.shape[1])[::-1]
    num_remaining = cost.shape[1]
    SR = jnp.full(cost.shape[0], False)
    SC = jnp.full(cost.shape[1], False)
    shortestPathCosts = jnp.full(cost.shape[1], jnp.inf)
    sink = -1
    break_cond=False
    def cond_fun(carry):
        # Keep iterating until a sink column is found or break_cond
        # (infeasible matrix) is raised.
        sink, minVal, remaining, SR, SC, shortestPathCosts, path, break_cond, i, u, v, row4col, cost, num_remaining=carry
        return (sink == -1)&(~break_cond)
    def while_loop_body(carry):
        # One outer iteration: scan all remaining columns from row i,
        # then advance the path by one column.
        sink, minVal, remaining, SR, SC, shortestPathCosts, path, break_cond, i, u, v, row4col, cost, num_remaining=carry
        index = -1
        lowest = jnp.inf
        SR = SR.at[i].set(True)
        def body_fun(carry):
            # Relax one unscanned column j from the current row i.
            cost, u, v, path, row4col, i, remaining, minVal, shortestPathCosts, lowest, index, it= carry
            j = remaining[it]
            # Reduced cost of extending the path through (i, j).
            r = minVal + cost[i, j] - u[i] - v[j]
            path = lax.cond(
                r < shortestPathCosts[j],
                lambda: path.at[j].set(i),
                lambda: path
            )
            shortestPathCosts = shortestPathCosts.at[j].min(r)
            # Track the closest column; ties prefer an unassigned one.
            index = lax.cond(
                (shortestPathCosts[j] < lowest) |
                ((shortestPathCosts[j] == lowest) & (row4col[j] == -1)),
                lambda: it,
                lambda: index
            )
            lowest = jnp.minimum(lowest, shortestPathCosts[j])
            it+=1
            return (cost, u, v, path, row4col, i, remaining, minVal, shortestPathCosts, lowest, index, it)
        # Inner loop over the first num_remaining entries of `remaining`
        # (it is the last carry element, used as the loop counter).
        carry=lax.while_loop(lambda x: x[-1]<num_remaining,body_fun,(cost, u, v, path, row4col, i, remaining, minVal, shortestPathCosts, lowest, index, 0))
        cost, u, v, path, row4col, i, remaining, minVal, shortestPathCosts, lowest, index, _=carry
        minVal = lowest
        def True_fun(carry):
            # Infeasible cost matrix: signal termination via break_cond.
            remaining,index,row4col,sink,i,SC,num_remaining,break_cond=carry
            sink = -1
            break_cond=True
            return (remaining,index,row4col,sink,i,SC,num_remaining,break_cond)
        def False_fun(carry):
            # Advance: either finish at an unassigned column (sink) or
            # continue from the row currently holding column j.
            remaining,index,row4col,sink,i,SC,num_remaining,break_cond=carry
            j = remaining[index]
            pred = row4col[j] == -1
            sink = lax.cond(pred, lambda: j, lambda: sink)
            i = lax.cond(~pred, lambda: row4col[j], lambda: i)
            SC = SC.at[j].set(True)
            # Remove column j from `remaining` by swapping in the last entry.
            num_remaining -= 1
            remaining = remaining.at[index].set(remaining[num_remaining])
            return (remaining,index,row4col,sink,i,SC,num_remaining,break_cond)
        carry=lax.cond(minVal == jnp.inf,True_fun,False_fun,(remaining,index,row4col,sink,i,SC,num_remaining,break_cond))
        remaining,index,row4col,sink,i,SC,num_remaining,break_cond=carry
        return (sink, minVal, remaining, SR, SC, shortestPathCosts, path, break_cond, i, u, v, row4col, cost, num_remaining)
    carry=lax.while_loop(cond_fun,while_loop_body,(sink, minVal, remaining, SR, SC, shortestPathCosts, path, break_cond, i, u, v, row4col, cost, num_remaining))
    sink, minVal, remaining, SR, SC, shortestPathCosts, path, break_cond, i, u, v, row4col, cost, num_remaining=carry
    return sink, minVal, remaining, SR, SC, shortestPathCosts, path
@jax.jit
def solve(cost):
    '''
    Solves the linear sum assignment problem using the Hungarian algorithm.
    adapted from https://github.com/google/jax/issues/10403

    Fully jittable: the per-row loop is a ``lax.scan`` and the
    path-augmentation loop is a ``lax.while_loop``.  The rectangular
    (transpose) handling of the original is commented out, so this
    version assumes a square cost matrix — TODO confirm for non-square
    inputs.

    Parameters:
    - cost (ndarray): The cost matrix representing the assignment problem.

    Returns:
    - row_ind (ndarray): The row indices of the assigned elements.
    - col_ind (ndarray): The column indices of the assigned elements.
    '''
    # transpose = cost.shape[1] < cost.shape[0]
    # if transpose:  # decide whether the matrix needs transposing; not needed for square matrices
    # cost = cost.T
    # Dual variables for rows (u) and columns (v).
    u = jnp.full(cost.shape[0], 0.)
    v = jnp.full(cost.shape[1], 0.)
    path = jnp.full(cost.shape[1], -1)
    # col4row[i]: column assigned to row i; row4col[j]: row assigned to
    # column j; -1 means unassigned.
    col4row = jnp.full(cost.shape[0], -1)
    row4col = jnp.full(cost.shape[1], -1)
    def loop_body(carry,curRow):
        # One scan step: augment the matching by one row (curRow).
        u,v,path,col4row,row4col,cost=carry
        j, minVal, remaining, SR, SC, shortestPathCosts, path = augmenting_path(cost, u, v, path, row4col, curRow)
        # Update duals along the scanned rows/columns; jnp.where replaces
        # the boolean-mask .at[] updates of the eager version so the
        # update is trace-friendly.
        u = u.at[curRow].add(minVal)
        mask = SR & (jnp.arange(cost.shape[0]) != curRow)
        u=jnp.where(mask,u+minVal-shortestPathCosts[col4row],u)
        #u = u.at[mask].add(minVal - shortestPathCosts[col4row][mask])
        v=jnp.where(SC,v+shortestPathCosts-minVal,v)
        #v = v.at[SC].add(shortestPathCosts[SC] - minVal)
        def while_loop_body(carry):
            # Walk predecessors from the sink column back to curRow,
            # flipping assignments along the augmenting path.
            path, j, row4col, col4row, break_cond=carry
            i = path[j]
            row4col = row4col.at[j].set(i)
            col4row, j = col4row.at[i].set(j), col4row[i]
            # Loop continues while we have not reached the starting row.
            break_cond=~(i==curRow)
            return (path,j, row4col, col4row, break_cond)
        carry=lax.while_loop(lambda x: x[-1],while_loop_body,(path,j, row4col, col4row, True))
        path,j, row4col, col4row, break_cond=carry
        return (u,v,path,col4row,row4col,cost),curRow
    carry,_=lax.scan(loop_body,(u,v,path,col4row,row4col,cost),jnp.arange(cost.shape[0]))
    u,v,path,col4row,row4col,cost=carry
    return jnp.arange(cost.shape[0]), col4row
    # Rectangular-case epilogue from the original, disabled along with the
    # transpose above:
    # if transpose:
    # v = col4row.argsort()
    # return col4row[v], v
    # else:
    # return jnp.arange(cost.shape[0]), col4row
def main():
    """Differential + timing test of the jitted ``solve`` against SciPy.

    Draws random 5x5 cost matrices, times both solvers, prints agreement
    per trial, and stops at the first disagreement.  The first printed
    'jax time' includes JIT compilation and is not representative.
    """
    key = random.PRNGKey(0)
    for t in count():
        # Two splits kept so the key stream matches the original script;
        # the first subkey fed a random-shape draw that became unused once
        # the cost matrix was fixed at (5, 5).
        key, subkey = random.split(key)
        key, subkey = random.split(key)
        cost = random.uniform(subkey, (5, 5))
        st = time.perf_counter()
        row_ind_1, col_ind_1 = linear_sum_assignment(cost)
        print('scipy time =', time.perf_counter() - st)
        st = time.perf_counter()
        row_ind_2, col_ind_2 = solve(cost)
        print('jax time =', time.perf_counter() - st)
        agree = bool((row_ind_1 == row_ind_2).all() and
                     (col_ind_1 == col_ind_2).all())
        print('{:5} {}'.format(t, agree))
        # The original wrote `if ~(...)`: bitwise NOT of a Python bool is
        # always truthy (~True == -2), so `not` is the correct negation.
        if not agree:
            break

if __name__ == '__main__':
    main()
Given this comment, let's move further discussion to google-deepmind/optax#954. |
Implement `scipy.optimize.linear_sum_assignment`, which solves the assignment problem. Among other things, this is useful for estimating the Wasserstein distance between two distributions based on their empirical measures.

The text was updated successfully, but these errors were encountered: