From e4f60cdbab75bbd20d90efa48601e124f75fc59b Mon Sep 17 00:00:00 2001 From: Tom Schammo Date: Thu, 4 Jan 2024 16:36:22 +0100 Subject: [PATCH] test(thin_out): Make `thin_out` test a controlled RNG test Fixes issue #4, deals with part of issue #2. --- LidarAug/cpp/test/test.cpp | 46 +++++++++++++++++++++++++------------- 1 file changed, 31 insertions(+), 15 deletions(-) diff --git a/LidarAug/cpp/test/test.cpp b/LidarAug/cpp/test/test.cpp index a3e9b3c..0ac1238 100644 --- a/LidarAug/cpp/test/test.cpp +++ b/LidarAug/cpp/test/test.cpp @@ -113,21 +113,6 @@ TEST(DrawUniformValuesTest, BasicAssertions) { } } -TEST(ThinOutTest, BasicAssertions) { - auto points = torch::rand({2, 10, 4}); - dimensions dims_original = {points.size(0), points.size(1), points.size(2)}; - auto new_points = thin_out(points, 1); - dimensions dims_edited = {new_points.size(0), new_points.size(1), - new_points.size(2)}; - - EXPECT_EQ(dims_edited.batch_size, dims_original.batch_size); - EXPECT_LT(dims_edited.num_items, dims_original.num_items) - << "Expected the amounts to have reduced...\noriginal:\n" - << points << "\nnew:\n" - << new_points; - EXPECT_EQ(dims_edited.num_features, dims_original.num_features); -} - TEST(RandomNoiseTest, BasicAssertions) { auto points = torch::tensor({{{1.0, 2.0, 3.0, 10.9}, {4.0, 5.0, 6.0, -10.0}}}); @@ -499,6 +484,37 @@ TEST(RotateRandomTest, BasicAssertions) { << "\nexpected_labels:\n" << expected_labels; } + +TEST(ThinOutTest, BasicAssertions) { + constexpr tensor_size_t BATCHES = 2; + constexpr tensor_size_t ITEMS = 10; + + const auto points = torch::rand({BATCHES, ITEMS, 4}, torch::kF32); + auto new_points = thin_out(points, 1); + + // NOTE(tom): percent = 0.653750002, indices = {9, 4, 8, 0} + const auto indices = torch::tensor({9, 4, 8, 0}); + + const auto expected_points = points.index_select(1, indices); + + EXPECT_EQ(BATCHES, points.size(0)) << "`points` was not supposed to change!"; + EXPECT_EQ(ITEMS, points.size(1)) << "`points` was not supposed to change!"; + + EXPECT_EQ(BATCHES, expected_points.size(0)) + << "Batch dimensions were not supposed to change!"; + EXPECT_EQ(BATCHES, new_points.size(0)) + << "Batch dimensions were not supposed to change!"; + + EXPECT_EQ(new_points.size(1), expected_points.size(1)) + << "Number of points does not match!"; + + EXPECT_TRUE(new_points.equal(expected_points)) + << "Thin out not as expected:\noriginal:\n" + << points << "\nexpected:\n" + << expected_points << "\nactual:\n" + << new_points; +} + #endif // NOLINTEND