From e9b702246664082c80fd2cd401c66a7e159a3507 Mon Sep 17 00:00:00 2001 From: Tom Schammo Date: Wed, 30 Oct 2024 13:17:26 +0100 Subject: [PATCH] feat: Add multiple constructors for `distribution_ranges` Add a second constructor that just takes `x`, `y` & `z` `_range` in case `uniform_range` is not used (such as in `generate_noise_filter`). --- cpp/include/transformation_bindings.hpp | 4 +++- cpp/include/utils.hpp | 8 ++++++++ src/LidarAug/transformations.pyi | 8 +++++++- 3 files changed, 18 insertions(+), 2 deletions(-) diff --git a/cpp/include/transformation_bindings.hpp b/cpp/include/transformation_bindings.hpp index 031ea54..1d1b78c 100644 --- a/cpp/include/transformation_bindings.hpp +++ b/cpp/include/transformation_bindings.hpp @@ -63,5 +63,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { pybind11::class_>(m, "DistributionRanges") .def(pybind11::init, cpp_utils::range, - cpp_utils::range, cpp_utils::range>()); + cpp_utils::range, cpp_utils::range>()) + .def(pybind11::init, cpp_utils::range, + cpp_utils::range>()); } diff --git a/cpp/include/utils.hpp b/cpp/include/utils.hpp index b46866b..84904d3 100644 --- a/cpp/include/utils.hpp +++ b/cpp/include/utils.hpp @@ -330,6 +330,14 @@ template struct range { template struct distribution_ranges { range x_range, y_range, z_range, uniform_range; + + distribution_ranges(range x_range, range y_range, range z_range, + range uniform_range) + : x_range(x_range), y_range(y_range), z_range(z_range), + uniform_range(uniform_range){}; + distribution_ranges(range x_range, range y_range, range z_range) + : x_range(x_range), y_range(y_range), z_range(z_range), + uniform_range(0, 0){}; }; /** diff --git a/src/LidarAug/transformations.pyi b/src/LidarAug/transformations.pyi index 16bd470..268083a 100644 --- a/src/LidarAug/transformations.pyi +++ b/src/LidarAug/transformations.pyi @@ -1,5 +1,5 @@ from enum import Enum -from typing import Tuple +from typing import Tuple, overload from torch import Tensor @@ -32,11 +32,17 @@ class DistributionRanges: z_range: DistributionRange uniform_range: DistributionRange + @overload def __init__(self, x_range: DistributionRange, y_range: DistributionRange, z_range: DistributionRange, uniform_range: DistributionRange) -> None: ... + @overload + def __init__(self, x_range: DistributionRange, y_range: DistributionRange, + z_range: DistributionRange) -> None: + ... + def translate(points: Tensor, translation: Tensor) -> None: """