Skip to content

Commit

Permalink
flow arg for radius
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s committed Jun 4, 2019
1 parent 4047c05 commit e3c3b13
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 13 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
['cuda/rw.cpp', 'cuda/rw_kernel.cu']),
]

__version__ = '1.4.1'
__version__ = '1.4.2'
url = 'https://github.com/rusty1s/pytorch_cluster'

install_requires = ['scipy']
Expand Down
2 changes: 0 additions & 2 deletions test/test_knn.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,10 @@ def test_knn_graph(dtype, device):

row, col = knn_graph(x, k=2, flow='target_to_source')
col = col.view(-1, 2).sort(dim=-1)[0].view(-1)

assert row.tolist() == [0, 0, 1, 1, 2, 2, 3, 3]
assert col.tolist() == [1, 3, 0, 2, 1, 3, 0, 2]

row, col = knn_graph(x, k=2, flow='source_to_target')
row = row.view(-1, 2).sort(dim=-1)[0].view(-1)

assert row.tolist() == [1, 3, 0, 2, 1, 3, 0, 2]
assert col.tolist() == [0, 0, 1, 1, 2, 2, 3, 3]
12 changes: 9 additions & 3 deletions test/test_radius.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,12 @@ def test_radius_graph(dtype, device):
[+1, -1],
], dtype, device)

out = radius_graph(x, r=2)
assert coalesce(out).tolist() == [[0, 0, 1, 1, 2, 2, 3, 3],
[1, 3, 0, 2, 1, 3, 0, 2]]
row, col = radius_graph(x, r=2, flow='target_to_source')
col = col.view(-1, 2).sort(dim=-1)[0].view(-1)
assert row.tolist() == [0, 0, 1, 1, 2, 2, 3, 3]
assert col.tolist() == [1, 3, 0, 2, 1, 3, 0, 2]

row, col = radius_graph(x, r=2, flow='source_to_target')
row = row.view(-1, 2).sort(dim=-1)[0].view(-1)
assert row.tolist() == [1, 3, 0, 2, 1, 3, 0, 2]
assert col.tolist() == [0, 0, 1, 1, 2, 2, 3, 3]
2 changes: 1 addition & 1 deletion torch_cluster/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from .sampler import neighbor_sampler
from .rw import random_walk

__version__ = '1.4.1'
__version__ = '1.4.2'

__all__ = [
'graclus_cluster',
Expand Down
19 changes: 13 additions & 6 deletions torch_cluster/radius.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,12 @@ def radius(x, y, r, batch_x=None, batch_y=None, max_num_neighbors=32):
return torch.stack([row[mask], col[mask]], dim=0)


def radius_graph(x, r, batch=None, loop=False, max_num_neighbors=32):
def radius_graph(x,
r,
batch=None,
loop=False,
max_num_neighbors=32,
flow='source_to_target'):
r"""Computes graph edges to all points within a given distance.
Args:
Expand All @@ -87,6 +92,9 @@ def radius_graph(x, r, batch=None, loop=False, max_num_neighbors=32):
self-loops. (default: :obj:`False`)
max_num_neighbors (int, optional): The maximum number of neighbors to
return for each element in :obj:`y`. (default: :obj:`32`)
flow (string, optional): The flow direction when using in combination
with message passing (:obj:`"source_to_target"` or
:obj:`"target_to_source"`). (default: :obj:`"source_to_target"`)
:rtype: :class:`LongTensor`
Expand All @@ -102,11 +110,10 @@ def radius_graph(x, r, batch=None, loop=False, max_num_neighbors=32):
>>> edge_index = radius_graph(x, r=1.5, batch=batch, loop=False)
"""

edge_index = radius(x, x, r, batch, batch, max_num_neighbors + 1)
row, col = edge_index
assert flow in ['source_to_target', 'target_to_source']
row, col = radius(x, x, r, batch, batch, max_num_neighbors + 1)
row, col = (col, row) if flow == 'source_to_target' else (row, col)
if not loop:
row, col = edge_index
mask = row != col
row, col = row[mask], col[mask]
edge_index = torch.stack([row, col], dim=0)
return edge_index
return torch.stack([row, col], dim=0)

0 comments on commit e3c3b13

Please sign in to comment.