Skip to content

Commit

Permalink
add flow to knn call
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s committed Jun 4, 2019
1 parent 5221414 commit fefd2cb
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 6 deletions.
8 changes: 7 additions & 1 deletion test/test_knn.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,14 @@ def test_knn_graph(dtype, device):
[+1, -1],
], dtype, device)

row, col = knn_graph(x, k=2)
row, col = knn_graph(x, k=2, flow='target_to_source')
col = col.view(-1, 2).sort(dim=-1)[0].view(-1)

assert row.tolist() == [0, 0, 1, 1, 2, 2, 3, 3]
assert col.tolist() == [1, 3, 0, 2, 1, 3, 0, 2]

row, col = knn_graph(x, k=2, flow='source_to_target')
row = row.view(-1, 2).sort(dim=-1)[0].view(-1)

assert row.tolist() == [1, 3, 0, 2, 1, 3, 0, 2]
assert col.tolist() == [0, 0, 1, 1, 2, 2, 3, 3]
13 changes: 8 additions & 5 deletions torch_cluster/knn.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def knn(x, y, k, batch_x=None, batch_y=None):
return torch.stack([row, col], dim=0)


def knn_graph(x, k, batch=None, loop=False):
def knn_graph(x, k, batch=None, loop=False, flow='source_to_target'):
r"""Computes graph edges to the nearest :obj:`k` points.
Args:
Expand All @@ -91,6 +91,9 @@ def knn_graph(x, k, batch=None, loop=False):
node to a specific example. (default: :obj:`None`)
loop (bool, optional): If :obj:`True`, the graph will contain
self-loops. (default: :obj:`False`)
flow (string, optional): The flow direction when using in combination
with message passing (:obj:`"source_to_target"` or
:obj:`"target_to_source"`). (default: :obj:`"source_to_target"`)
:rtype: :class:`LongTensor`
Expand All @@ -106,10 +109,10 @@ def knn_graph(x, k, batch=None, loop=False):
>>> edge_index = knn_graph(x, k=2, batch=batch, loop=False)
"""

edge_index = knn(x, x, k if loop else k + 1, batch, batch)
assert flow in ['source_to_target', 'target_to_source']
row, col = knn(x, x, k if loop else k + 1, batch, batch)
row, col = (col, row) if flow == 'source_to_target' else (row, col)
if not loop:
row, col = edge_index
mask = row != col
row, col = row[mask], col[mask]
edge_index = torch.stack([row, col], dim=0)
return edge_index
return torch.stack([row, col], dim=0)

0 comments on commit fefd2cb

Please sign in to comment.