Fix the number of neighbors bug in radius_graph (#228)
* Fix the number of neighbors bug in `radius_graph`

Signed-off-by: Xuangui Huang <[email protected]>

* fix linting issue

---------

Signed-off-by: Xuangui Huang <[email protected]>
stslxg-nv authored Sep 10, 2024
1 parent e1e788b commit d16c692
Showing 7 changed files with 81 additions and 25 deletions.
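
For context, the new tests below pin down the behaviour this commit establishes: previously `radius_graph` over-requested `max_num_neighbors + 1` candidates and filtered self-loops afterwards, which could leave a node with more (or fewer) neighbors than the cap when several points coincide. A minimal sketch of the fixed behaviour, assuming a build of torch-cluster that includes this commit:

    import torch
    from torch_cluster import radius_graph

    # Four coincident points; ask for at most one neighbor per node.
    x = torch.tensor([[-1., -1.], [-1., -1.], [-1., -1.], [-1., -1.]])
    edge_index = radius_graph(x, r=100, max_num_neighbors=1)

    # Every node now has exactly one incident edge, so the number of
    # (directed) edges equals the number of nodes.
    assert edge_index.size(1) == x.size(0)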
27 changes: 18 additions & 9 deletions csrc/cpu/radius_cpu.cpp
@@ -7,7 +7,8 @@
torch::Tensor radius_cpu(torch::Tensor x, torch::Tensor y,
torch::optional<torch::Tensor> ptr_x,
torch::optional<torch::Tensor> ptr_y, double r,
int64_t max_num_neighbors, int64_t num_workers) {
int64_t max_num_neighbors, int64_t num_workers,
bool ignore_same_index) {

CHECK_CPU(x);
CHECK_INPUT(x.dim() == 2);
@@ -54,10 +55,14 @@ torch::Tensor radius_cpu(torch::Tensor x, torch::Tensor y,
size_t num_matches = mat_index.index->radiusSearch(
y_data + i * y.size(1), r * r, ret_matches, params);

for (size_t j = 0; j < std::min(num_matches, (size_t)max_num_neighbors);
j++) {
out_vec.push_back(ret_matches[j].first);
out_vec.push_back(i);
for (size_t j = 0, count = 0;
j < num_matches && count < (size_t)max_num_neighbors;
j++) {
if (!ignore_same_index || ret_matches[j].first != i) {
out_vec.push_back(ret_matches[j].first);
out_vec.push_back(i);
count++;
}
}
}

Expand Down Expand Up @@ -91,10 +96,14 @@ torch::Tensor radius_cpu(torch::Tensor x, torch::Tensor y,
size_t num_matches = mat_index.index->radiusSearch(
y_data + i * y.size(1), r * r, ret_matches, params);

for (size_t j = 0;
j < std::min(num_matches, (size_t)max_num_neighbors); j++) {
out_vec.push_back(x_start + ret_matches[j].first);
out_vec.push_back(i);
for (size_t j = 0, count = 0;
j < num_matches && count < (size_t)max_num_neighbors;
j++) {
if (!ignore_same_index || x_start + ret_matches[j].first != i) {
out_vec.push_back(x_start + ret_matches[j].first);
out_vec.push_back(i);
count++;
}
}
}
}
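The reworked CPU loop spends the `max_num_neighbors` budget only on edges that are actually emitted: a same-index match is skipped without counting when `ignore_same_index` is set. A Python mirror of that selection logic, for illustration only (the authoritative code is the C++ above; `select_neighbors` is a hypothetical helper):

    def select_neighbors(matches, query_index, max_num_neighbors, ignore_same_index):
        # matches: candidate x indices returned by the radius search for one y point.
        out = []
        count = 0
        for neighbor_index in matches:
            if count >= max_num_neighbors:
                break
            if ignore_same_index and neighbor_index == query_index:
                continue  # skip the self match without using up the budget
            out.append((neighbor_index, query_index))
            count += 1
        return out

    # Four coincident points, query index 2, budget 1:
    # select_neighbors([0, 1, 2, 3], 2, 1, True) -> [(0, 2)]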
3 changes: 2 additions & 1 deletion csrc/cpu/radius_cpu.h
@@ -5,4 +5,5 @@
torch::Tensor radius_cpu(torch::Tensor x, torch::Tensor y,
torch::optional<torch::Tensor> ptr_x,
torch::optional<torch::Tensor> ptr_y, double r,
int64_t max_num_neighbors, int64_t num_workers);
int64_t max_num_neighbors, int64_t num_workers,
bool ignore_same_index);
10 changes: 6 additions & 4 deletions csrc/cuda/radius_cuda.cu
@@ -13,7 +13,8 @@ radius_kernel(const scalar_t *__restrict__ x, const scalar_t *__restrict__ y,
const int64_t *__restrict__ ptr_y, int64_t *__restrict__ row,
int64_t *__restrict__ col, const scalar_t r, const int64_t n,
const int64_t m, const int64_t dim, const int64_t num_examples,
const int64_t max_num_neighbors) {
const int64_t max_num_neighbors,
const bool ignore_same_index) {

const int64_t n_y = blockIdx.x * blockDim.x + threadIdx.x;
if (n_y >= m)
@@ -29,7 +30,7 @@ radius_kernel(const scalar_t *__restrict__ x, const scalar_t *__restrict__ y,
(x[n_x * dim + d] - y[n_y * dim + d]);
}

if (dist < r) {
if (dist < r && !(ignore_same_index && n_y == n_x)) {
row[n_y * max_num_neighbors + count] = n_y;
col[n_y * max_num_neighbors + count] = n_x;
count++;
@@ -43,7 +44,8 @@ radius_kernel(const scalar_t *__restrict__ x, const scalar_t *__restrict__ y,
torch::Tensor radius_cuda(const torch::Tensor x, const torch::Tensor y,
torch::optional<torch::Tensor> ptr_x,
torch::optional<torch::Tensor> ptr_y, const double r,
const int64_t max_num_neighbors) {
const int64_t max_num_neighbors,
const bool ignore_same_index) {
CHECK_CUDA(x);
CHECK_CONTIGUOUS(x);
CHECK_INPUT(x.dim() == 2);
@@ -86,7 +88,7 @@ torch::Tensor radius_cuda(const torch::Tensor x, const torch::Tensor y,
ptr_x.value().data_ptr<int64_t>(),
ptr_y.value().data_ptr<int64_t>(), row.data_ptr<int64_t>(),
col.data_ptr<int64_t>(), r * r, x.size(0), y.size(0), x.size(1),
ptr_x.value().numel() - 1, max_num_neighbors);
ptr_x.value().numel() - 1, max_num_neighbors, ignore_same_index);
});

auto mask = row != -1;
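The CUDA kernel applies the same rule inside the distance test, so both backends should now agree on the coincident-point case. A quick parity check, assuming a CUDA-enabled build that includes this commit:

    import torch
    from torch_cluster import radius_graph

    # Two batch examples of four coincident points each.
    x = torch.tensor([[-1., -1.]] * 8)
    batch = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1])
    cpu_edges = radius_graph(x, r=100, batch=batch, max_num_neighbors=1)
    if torch.cuda.is_available():
        gpu_edges = radius_graph(x.cuda(), r=100, batch=batch.cuda(),
                                 max_num_neighbors=1)
        # Both kernels cap coincident points at one neighbor per node.
        assert gpu_edges.size(1) == cpu_edges.size(1) == x.size(0)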
3 changes: 2 additions & 1 deletion csrc/cuda/radius_cuda.h
@@ -5,4 +5,5 @@
torch::Tensor radius_cuda(torch::Tensor x, torch::Tensor y,
torch::optional<torch::Tensor> ptr_x,
torch::optional<torch::Tensor> ptr_y, double r,
int64_t max_num_neighbors);
int64_t max_num_neighbors,
bool ignore_same_index);
7 changes: 4 additions & 3 deletions csrc/radius.cpp
@@ -22,15 +22,16 @@ PyMODINIT_FUNC PyInit__radius_cpu(void) { return NULL; }
CLUSTER_API torch::Tensor radius(torch::Tensor x, torch::Tensor y,
torch::optional<torch::Tensor> ptr_x,
torch::optional<torch::Tensor> ptr_y, double r,
int64_t max_num_neighbors, int64_t num_workers) {
int64_t max_num_neighbors, int64_t num_workers,
bool ignore_same_index) {
if (x.device().is_cuda()) {
#ifdef WITH_CUDA
return radius_cuda(x, y, ptr_x, ptr_y, r, max_num_neighbors);
return radius_cuda(x, y, ptr_x, ptr_y, r, max_num_neighbors, ignore_same_index);
#else
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
return radius_cpu(x, y, ptr_x, ptr_y, r, max_num_neighbors, num_workers);
return radius_cpu(x, y, ptr_x, ptr_y, r, max_num_neighbors, num_workers, ignore_same_index);
}
}

41 changes: 41 additions & 0 deletions test/test_radius.py
@@ -11,6 +11,15 @@ def to_set(edge_index):
return set([(i, j) for i, j in edge_index.t().tolist()])


def to_degree(edge_index):
_, counts = torch.unique(edge_index[1], return_counts=True)
return counts.tolist()


def to_batch(nodes):
return [int(i / 4) for i in nodes]


@pytest.mark.parametrize('dtype,device', product(floating_dtypes, devices))
def test_radius(dtype, device):
x = tensor([
@@ -74,6 +83,38 @@ def test_radius_graph(dtype, device):
assert to_set(edge_index) == set([(1, 0), (3, 0), (0, 1), (2, 1), (1, 2),
(3, 2), (0, 3), (2, 3)])

edge_index = radius_graph(x, r=100, flow='source_to_target',
max_num_neighbors=1)
assert set(to_degree(edge_index)) == set([1])

x = tensor([
[-1, -1],
[-1, -1],
[-1, -1],
[-1, -1],
], dtype, device)

edge_index = radius_graph(x, r=100, flow='source_to_target',
max_num_neighbors=1)
assert set(to_degree(edge_index)) == set([1])

x = tensor([
[-1, -1],
[-1, +1],
[+1, +1],
[+1, -1],
[-1, -1],
[-1, +1],
[+1, +1],
[+1, -1],
], dtype, device)
batch_x = tensor([0, 0, 0, 0, 1, 1, 1, 1], torch.long, device)

edge_index = radius_graph(x, r=100, batch=batch_x, flow='source_to_target',
max_num_neighbors=1)
assert set(to_degree(edge_index)) == set([1])
assert to_batch(edge_index[0]) == batch_x.tolist()


@pytest.mark.parametrize('dtype,device', product([torch.float], devices))
def test_radius_graph_large(dtype, device):
15 changes: 8 additions & 7 deletions torch_cluster/radius.py
@@ -12,6 +12,7 @@ def radius(
max_num_neighbors: int = 32,
num_workers: int = 1,
batch_size: Optional[int] = None,
ignore_same_index: bool = False
) -> torch.Tensor:
r"""Finds for each element in :obj:`y` all points in :obj:`x` within
distance :obj:`r`.
@@ -40,6 +41,9 @@
:obj:`None`, or the input lies on the GPU. (default: :obj:`1`)
batch_size (int, optional): The number of examples :math:`B`.
Automatically calculated if not given. (default: :obj:`None`)
ignore_same_index (bool, optional): If :obj:`True`, each element in
:obj:`y` ignores the point in :obj:`x` with the same index.
(default: :obj:`False`)
.. code-block:: python
@@ -80,7 +84,8 @@
ptr_y = torch.bucketize(arange, batch_y)

return torch.ops.torch_cluster.radius(x, y, ptr_x, ptr_y, r,
max_num_neighbors, num_workers)
max_num_neighbors, num_workers,
ignore_same_index)


def radius_graph(
@@ -133,15 +138,11 @@ def radius_graph(

assert flow in ['source_to_target', 'target_to_source']
edge_index = radius(x, x, r, batch, batch,
max_num_neighbors if loop else max_num_neighbors + 1,
num_workers, batch_size)
max_num_neighbors,
num_workers, batch_size, not loop)
if flow == 'source_to_target':
row, col = edge_index[1], edge_index[0]
else:
row, col = edge_index[0], edge_index[1]

if not loop:
mask = row != col
row, col = row[mask], col[mask]

return torch.stack([row, col], dim=0)
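
With this change `radius_graph` no longer requests `max_num_neighbors + 1` and post-filters self-loops; it forwards `not loop` as `ignore_same_index`, so the cap is enforced on the edges that are actually kept. A small usage sketch of the new keyword on `radius` itself (values chosen purely for illustration):

    import torch
    from torch_cluster import radius

    x = torch.tensor([[0., 0.], [1., 1.], [2., 2.]])
    # Match x against itself but skip the trivial i == i pairs.
    edge_index = radius(x, x, r=10.0, max_num_neighbors=2,
                        ignore_same_index=True)
    # No self-pair survives, and each point keeps at most two neighbors.
    assert not bool((edge_index[0] == edge_index[1]).any())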
