Skip to content

Commit

Permalink
feat: Add multiple constructors for distribution_ranges
Browse files Browse the repository at this point in the history
Add a second constructor that just takes `x`, `y` & `z` `_range` in case
`uniform_range` is not used (such as in `generate_noise_filter`).
  • Loading branch information
TomSchammo committed Oct 30, 2024
1 parent e7ace68 commit e9b7022
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 2 deletions.
4 changes: 3 additions & 1 deletion cpp/include/transformation_bindings.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,5 +63,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
pybind11::class_<cpp_utils::distribution_ranges<float>>(m,
"DistributionRanges")
.def(pybind11::init<cpp_utils::range<float>, cpp_utils::range<float>,
cpp_utils::range<float>, cpp_utils::range<float>>());
cpp_utils::range<float>, cpp_utils::range<float>>())
.def(pybind11::init<cpp_utils::range<float>, cpp_utils::range<float>,
cpp_utils::range<float>>());
}
8 changes: 8 additions & 0 deletions cpp/include/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,14 @@ template <typename T> struct range {

template <typename T> struct distribution_ranges {
range<T> x_range, y_range, z_range, uniform_range;

distribution_ranges<T>(range<T> x_range, range<T> y_range, range<T> z_range,
range<T> uniform_range)
: x_range(x_range), y_range(y_range), z_range(z_range),
uniform_range(uniform_range){};
distribution_ranges<T>(range<T> x_range, range<T> y_range, range<T> z_range)
: x_range(x_range), y_range(y_range), z_range(z_range),
uniform_range(0, 0){};
};

/**
Expand Down
8 changes: 7 additions & 1 deletion src/LidarAug/transformations.pyi
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from enum import Enum
from typing import Tuple
from typing import Tuple, overload
from torch import Tensor


Expand Down Expand Up @@ -32,11 +32,17 @@ class DistributionRanges:
z_range: DistributionRange
uniform_range: DistributionRange

@overload
def __init__(self, x_range: DistributionRange, y_range: DistributionRange,
z_range: DistributionRange,
uniform_range: DistributionRange) -> None:
...

@overload
def __init__(self, x_range: DistributionRange, y_range: DistributionRange,
z_range: DistributionRange) -> None:
...


def translate(points: Tensor, translation: Tensor) -> None:
"""
Expand Down

0 comments on commit e9b7022

Please sign in to comment.