Skip to content

Commit

Permalink
Add num_points to fps as an alternative to ratio
Browse files Browse the repository at this point in the history
  • Loading branch information
LarsHaalck authored and Lars Haalck committed Jun 24, 2024
1 parent 616704a commit 00af9dd
Show file tree
Hide file tree
Showing 8 changed files with 124 additions and 30 deletions.
2 changes: 1 addition & 1 deletion csrc/cluster.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ CLUSTER_INLINE_VARIABLE int64_t _cuda_version = cuda_version();
} // namespace cluster

CLUSTER_API torch::Tensor fps(torch::Tensor src, torch::Tensor ptr, double ratio,
bool random_start);
int64_t num_points, bool random_start);

CLUSTER_API torch::Tensor graclus(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_weight);
Expand Down
17 changes: 13 additions & 4 deletions csrc/cpu/fps_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,30 @@ inline torch::Tensor get_dist(torch::Tensor x, int64_t idx) {
}

torch::Tensor fps_cpu(torch::Tensor src, torch::Tensor ptr, torch::Tensor ratio,
bool random_start) {
torch::Tensor num_points, bool random_start) {

CHECK_CPU(src);
CHECK_CPU(ptr);
CHECK_CPU(ratio);
CHECK_CPU(num_points);
CHECK_INPUT(ptr.dim() == 1);

src = src.view({src.size(0), -1}).contiguous();
ptr = ptr.contiguous();
auto batch_size = ptr.numel() - 1;

auto deg = ptr.narrow(0, 1, batch_size) - ptr.narrow(0, 0, batch_size);
auto out_ptr = deg.toType(torch::kFloat) * ratio;
out_ptr = out_ptr.ceil().toType(torch::kLong).cumsum(0);

torch::Tensor out_ptr;
if (num_points.sum().item<int64_t>() == 0) {
out_ptr = deg.toType(torch::kFloat) * ratio;
out_ptr = out_ptr.ceil().toType(torch::kLong).cumsum(0);
} else {
TORCH_CHECK((deg.toType(torch::kLong) >= num_points.toType(torch::kLong)).all().item<int64_t>(),
"Passed tensor has fewer elements than requested number of returned points.")
out_ptr = deg.toType(torch::kLong)
.minimum(num_points.toType(torch::kLong))
.cumsum(0);
}
auto out = torch::empty({out_ptr[-1].data_ptr<int64_t>()[0]}, ptr.options());

auto ptr_data = ptr.data_ptr<int64_t>();
Expand Down
2 changes: 1 addition & 1 deletion csrc/cpu/fps_cpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
#include "../extensions.h"

torch::Tensor fps_cpu(torch::Tensor src, torch::Tensor ptr, torch::Tensor ratio,
bool random_start);
torch::Tensor num_points, bool random_start);
18 changes: 15 additions & 3 deletions csrc/cuda/fps_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,13 @@ __global__ void fps_kernel(const scalar_t *src, const int64_t *ptr,
}

torch::Tensor fps_cuda(torch::Tensor src, torch::Tensor ptr,
torch::Tensor ratio, bool random_start) {
torch::Tensor ratio, torch::Tensor num_points,
bool random_start) {

CHECK_CUDA(src);
CHECK_CUDA(ptr);
CHECK_CUDA(ratio);
CHECK_CUDA(num_points);
CHECK_INPUT(ptr.dim() == 1);
c10::cuda::MaybeSetDevice(src.get_device());

Expand All @@ -78,8 +80,18 @@ torch::Tensor fps_cuda(torch::Tensor src, torch::Tensor ptr,
auto batch_size = ptr.numel() - 1;

auto deg = ptr.narrow(0, 1, batch_size) - ptr.narrow(0, 0, batch_size);
auto out_ptr = deg.toType(ratio.scalar_type()) * ratio;
out_ptr = out_ptr.ceil().toType(torch::kLong).cumsum(0);
torch::Tensor out_ptr;
if (num_points.sum().item<int64_t>() == 0) {
out_ptr = deg.toType(ratio.scalar_type()) * ratio;
out_ptr = out_ptr.ceil().toType(torch::kLong).cumsum(0);
} else {
TORCH_CHECK((deg.toType(torch::kLong) >= num_points.toType(torch::kLong)).all().item<int64_t>(),
"Passed tensor has fewer elements than requested number of returned points.")
out_ptr = deg.toType(torch::kLong)
.minimum(num_points.toType(torch::kLong))
.cumsum(0);
}

out_ptr = torch::cat({torch::zeros({1}, ptr.options()), out_ptr}, 0);

torch::Tensor start;
Expand Down
3 changes: 2 additions & 1 deletion csrc/cuda/fps_cuda.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@
#include "../extensions.h"

torch::Tensor fps_cuda(torch::Tensor src, torch::Tensor ptr,
torch::Tensor ratio, bool random_start);
torch::Tensor ratio, torch::Tensor num_points,
bool random_start);
9 changes: 5 additions & 4 deletions csrc/fps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,17 @@ PyMODINIT_FUNC PyInit__fps_cpu(void) { return NULL; }
#endif
#endif

CLUSTER_API torch::Tensor fps(torch::Tensor src, torch::Tensor ptr, torch::Tensor ratio,
bool random_start) {
CLUSTER_API torch::Tensor fps(torch::Tensor src, torch::Tensor ptr,
torch::Tensor ratio, torch::Tensor num_points,
bool random_start) {
if (src.device().is_cuda()) {
#ifdef WITH_CUDA
return fps_cuda(src, ptr, ratio, random_start);
return fps_cuda(src, ptr, ratio, num_points, random_start);
#else
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
return fps_cpu(src, ptr, ratio, random_start);
return fps_cpu(src, ptr, ratio, num_points, random_start);
}
}

Expand Down
25 changes: 23 additions & 2 deletions test/test_fps.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

@torch.jit.script
def fps2(x: Tensor, ratio: Tensor) -> Tensor:
return fps(x, None, ratio, False)
return fps(x, None, ratio, None, False)


@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
Expand All @@ -33,26 +33,36 @@ def test_fps(dtype, device):

out = fps(x, batch, ratio=0.5, random_start=False)
assert out.tolist() == [0, 2, 4, 6]
out = fps(x, batch, num_points=2, random_start=False)
assert out.tolist() == [0, 2, 4, 6]

out = fps(x, batch, num_points=4, random_start=False)
assert out.tolist() == [0, 2, 1, 3, 4, 6, 5, 7]

ratio = torch.tensor(0.5, device=device)
out = fps(x, batch, ratio=ratio, random_start=False)
assert out.tolist() == [0, 2, 4, 6]

out = fps(x, ptr=ptr_list, ratio=0.5, random_start=False)
assert out.tolist() == [0, 2, 4, 6]

out = fps(x, ptr=ptr, ratio=0.5, random_start=False)
assert out.tolist() == [0, 2, 4, 6]

ratio = torch.tensor([0.5, 0.5], device=device)
out = fps(x, batch, ratio=ratio, random_start=False)
assert out.tolist() == [0, 2, 4, 6]

num = torch.tensor([2, 2], device=device)
out = fps(x, batch, num_points=num, random_start=False)
assert out.tolist() == [0, 2, 4, 6]

out = fps(x, random_start=False)
assert out.sort()[0].tolist() == [0, 5, 6, 7]

out = fps(x, ratio=0.5, random_start=False)
assert out.sort()[0].tolist() == [0, 5, 6, 7]
out = fps(x, num_points=4, random_start=False)
assert out.sort()[0].tolist() == [0, 5, 6, 7]

out = fps(x, ratio=torch.tensor(0.5, device=device), random_start=False)
assert out.sort()[0].tolist() == [0, 5, 6, 7]
Expand All @@ -63,6 +73,17 @@ def test_fps(dtype, device):
out = fps2(x, torch.tensor([0.5], device=device))
assert out.sort()[0].tolist() == [0, 5, 6, 7]

# requesting too many points
with pytest.raises(RuntimeError):
out = fps(x, batch, num_points=100, random_start=False)

with pytest.raises(RuntimeError):
out = fps(x, batch, num_points=5, random_start=False)

# invalid argument combination
with pytest.raises(ValueError):
out = fps(x, batch, ratio=0.0, num_points=0, random_start=False)


@pytest.mark.parametrize('device', devices)
def test_random_fps(device):
Expand Down
78 changes: 64 additions & 14 deletions torch_cluster/fps.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,33 +7,57 @@


@torch.jit._overload # noqa
def fps(src, batch, ratio, random_start, batch_size, ptr): # noqa
# type: (Tensor, Optional[Tensor], Optional[float], bool, Optional[int], Optional[Tensor]) -> Tensor # noqa
def fps(src, batch, ratio, num_points, random_start, batch_size, ptr): # noqa
# type: (Tensor, Optional[Tensor], Optional[float], Optional[int], bool, Optional[int], Optional[Tensor]) -> Tensor # noqa
pass # pragma: no cover


@torch.jit._overload # noqa
def fps(src, batch, ratio, random_start, batch_size, ptr): # noqa
# type: (Tensor, Optional[Tensor], Optional[Tensor], bool, Optional[int], Optional[Tensor]) -> Tensor # noqa
def fps(src, batch, ratio, num_points, random_start, batch_size, ptr): # noqa
# type: (Tensor, Optional[Tensor], Optional[Tensor], Optional[int], bool, Optional[int], Optional[Tensor]) -> Tensor # noqa
pass # pragma: no cover


@torch.jit._overload # noqa
def fps(src, batch, ratio, random_start, batch_size, ptr): # noqa
# type: (Tensor, Optional[Tensor], Optional[float], bool, Optional[int], Optional[List[int]]) -> Tensor # noqa
def fps(src, batch, ratio, num_points, random_start, batch_size, ptr): # noqa
# type: (Tensor, Optional[Tensor], Optional[float], Optional[int], bool, Optional[int], Optional[List[int]]) -> Tensor # noqa
pass # pragma: no cover


@torch.jit._overload # noqa
def fps(src, batch, ratio, random_start, batch_size, ptr): # noqa
# type: (Tensor, Optional[Tensor], Optional[Tensor], bool, Optional[int], Optional[List[int]]) -> Tensor # noqa
def fps(src, batch, ratio, num_points, random_start, batch_size, ptr): # noqa
# type: (Tensor, Optional[Tensor], Optional[Tensor], Optional[int], bool, Optional[int], Optional[List[int]]) -> Tensor # noqa
pass # pragma: no cover

@torch.jit._overload # noqa
def fps(src, batch, ratio, num_points, random_start, batch_size, ptr): # noqa
# type: (Tensor, Optional[Tensor], Optional[float], Optional[Tensor], bool, Optional[int], Optional[Tensor]) -> Tensor # noqa
pass # pragma: no cover


@torch.jit._overload # noqa
def fps(src, batch, ratio, num_points, random_start, batch_size, ptr): # noqa
# type: (Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], bool, Optional[int], Optional[Tensor]) -> Tensor # noqa
pass # pragma: no cover


@torch.jit._overload # noqa
def fps(src, batch, ratio, num_points, random_start, batch_size, ptr): # noqa
# type: (Tensor, Optional[Tensor], Optional[float], Optional[Tensor], bool, Optional[int], Optional[List[int]]) -> Tensor # noqa
pass # pragma: no cover


@torch.jit._overload # noqa
def fps(src, batch, ratio, num_points, random_start, batch_size, ptr): # noqa
# type: (Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], bool, Optional[int], Optional[List[int]]) -> Tensor # noqa
pass # pragma: no cover


def fps( # noqa
src: torch.Tensor,
batch: Optional[Tensor] = None,
ratio: Optional[Union[Tensor, float]] = None,
num_points: Optional[Union[Tensor, int]] = None,
random_start: bool = True,
batch_size: Optional[int] = None,
ptr: Optional[Union[Tensor, List[int]]] = None,
Expand All @@ -50,7 +74,11 @@ def fps( # noqa
:math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
node to a specific example. (default: :obj:`None`)
ratio (float or Tensor, optional): Sampling ratio.
Only ratio or num_points can be specified.
(default: :obj:`0.5`)
num_points (int, optional): Number of returned points.
Only ratio or num_points can be specified.
(default: :obj:`None`)
random_start (bool, optional): If set to :obj:`False`, use the first
node in :math:`\mathbf{X}` as starting node. (default: obj:`True`)
batch_size (int, optional): The number of examples :math:`B`.
Expand All @@ -71,25 +99,47 @@ def fps( # noqa
batch = torch.tensor([0, 0, 0, 0])
index = fps(src, batch, ratio=0.5)
"""
# check if only of of ratio or num_points is set
# if no one is set, fallback to ratio = 0.5
if ratio is not None and num_points is not None:
raise ValueError("Only one of ratio and num_points can be specified.")

r: Optional[Tensor] = None
if ratio is None:
r = torch.tensor(0.5, dtype=src.dtype, device=src.device)
if num_points is None:
r = torch.tensor(0.5, dtype=src.dtype, device=src.device)
else:
r = torch.tensor(0.0, dtype=src.dtype, device=src.device)
elif isinstance(ratio, float):
r = torch.tensor(ratio, dtype=src.dtype, device=src.device)
else:
r = ratio
assert r is not None

num: Optional[Tensor] = None
if num_points is None:
num = torch.tensor(0, dtype=torch.long, device=src.device)
elif isinstance(num_points, int):
num = torch.tensor(num_points, dtype=torch.long, device=src.device)
else:
num = num_points

assert r is not None and num is not None

if r.sum() == 0 and num.sum() == 0:
raise ValueError("At least one of ratio or num_points should be > 0")

if ptr is not None:
if isinstance(ptr, list) and torch_cluster.typing.WITH_PTR_LIST:
return torch.ops.torch_cluster.fps_ptr_list(
src, ptr, r, random_start)
src, ptr, r, random_start
)

if isinstance(ptr, list):
return torch.ops.torch_cluster.fps(
src, torch.tensor(ptr, device=src.device), r, random_start)
src, torch.tensor(ptr, device=src.device), r, num, random_start
)
else:
return torch.ops.torch_cluster.fps(src, ptr, r, random_start)
return torch.ops.torch_cluster.fps(src, ptr, r, num, random_start)

if batch is not None:
assert src.size(0) == batch.numel()
Expand All @@ -104,4 +154,4 @@ def fps( # noqa
else:
ptr_vec = torch.tensor([0, src.size(0)], device=src.device)

return torch.ops.torch_cluster.fps(src, ptr_vec, r, random_start)
return torch.ops.torch_cluster.fps(src, ptr_vec, r, num, random_start)

0 comments on commit 00af9dd

Please sign in to comment.