diff --git a/csrc/cpu/sampler_cpu.cpp b/csrc/cpu/sampler_cpu.cpp index 90c23d8e..5b3f68c6 100644 --- a/csrc/cpu/sampler_cpu.cpp +++ b/csrc/cpu/sampler_cpu.cpp @@ -15,9 +15,10 @@ torch::Tensor neighbor_sampler_cpu(torch::Tensor start, torch::Tensor rowptr, auto num_neighbors = row_end - row_start; int64_t size = count; - if (count < 1) { + if (count < 1) size = int64_t(ceil(factor * float(num_neighbors))); - } + if (size > num_neighbors) + size = num_neighbors; // If the number of neighbors is approximately equal to the number of // neighbors which are requested, we use `randperm` to sample without @@ -26,16 +27,16 @@ torch::Tensor neighbor_sampler_cpu(torch::Tensor start, torch::Tensor rowptr, std::unordered_set set; if (size < 0.7 * float(num_neighbors)) { while (int64_t(set.size()) < size) { - int64_t sample = (rand() % num_neighbors) + row_start; - set.insert(sample); + int64_t sample = rand() % num_neighbors; + set.insert(sample + row_start); } std::vector v(set.begin(), set.end()); e_ids.insert(e_ids.end(), v.begin(), v.end()); } else { - auto sample = at::randperm(num_neighbors, start.options()) + row_start; + auto sample = torch::randperm(num_neighbors, start.options()); auto sample_data = sample.data_ptr(); for (auto j = 0; j < size; j++) { - e_ids.push_back(sample_data[j]); + e_ids.push_back(sample_data[j] + row_start); } } } diff --git a/setup.py b/setup.py index 22a5272b..5666ba05 100644 --- a/setup.py +++ b/setup.py @@ -63,7 +63,7 @@ def get_extensions(): setup( name='torch_cluster', - version='1.5.1', + version='1.5.2', author='Matthias Fey', author_email='matthias.fey@tu-dortmund.de', url='https://github.com/rusty1s/pytorch_cluster', diff --git a/torch_cluster/__init__.py b/torch_cluster/__init__.py index 54e4076c..f5166e67 100644 --- a/torch_cluster/__init__.py +++ b/torch_cluster/__init__.py @@ -3,7 +3,7 @@ import torch -__version__ = '1.5.1' +__version__ = '1.5.2' expected_torch_version = (1, 4) try: