Skip to content

Commit

Permalink
fix neighbor sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s committed Mar 17, 2020
1 parent 69fada5 commit 4116005
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 8 deletions.
13 changes: 7 additions & 6 deletions csrc/cpu/sampler_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@ torch::Tensor neighbor_sampler_cpu(torch::Tensor start, torch::Tensor rowptr,
auto num_neighbors = row_end - row_start;

int64_t size = count;
if (count < 1) {
if (count < 1)
size = int64_t(ceil(factor * float(num_neighbors)));
}
if (size > num_neighbors)
size = num_neighbors;

// If the number of neighbors is approximately equal to the number of
// neighbors which are requested, we use `randperm` to sample without
Expand All @@ -26,16 +27,16 @@ torch::Tensor neighbor_sampler_cpu(torch::Tensor start, torch::Tensor rowptr,
std::unordered_set<int64_t> set;
if (size < 0.7 * float(num_neighbors)) {
while (int64_t(set.size()) < size) {
int64_t sample = (rand() % num_neighbors) + row_start;
set.insert(sample);
int64_t sample = rand() % num_neighbors;
set.insert(sample + row_start);
}
std::vector<int64_t> v(set.begin(), set.end());
e_ids.insert(e_ids.end(), v.begin(), v.end());
} else {
auto sample = at::randperm(num_neighbors, start.options()) + row_start;
auto sample = torch::randperm(num_neighbors, start.options());
auto sample_data = sample.data_ptr<int64_t>();
for (auto j = 0; j < size; j++) {
e_ids.push_back(sample_data[j]);
e_ids.push_back(sample_data[j] + row_start);
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def get_extensions():

setup(
name='torch_cluster',
version='1.5.1',
version='1.5.2',
author='Matthias Fey',
author_email='[email protected]',
url='https://github.com/rusty1s/pytorch_cluster',
Expand Down
2 changes: 1 addition & 1 deletion torch_cluster/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import torch

__version__ = '1.5.1'
__version__ = '1.5.2'
expected_torch_version = (1, 4)

try:
Expand Down

0 comments on commit 4116005

Please sign in to comment.