diff --git a/evo/core/filters.py b/evo/core/filters.py index ef2777c4..0db1bda5 100644 --- a/evo/core/filters.py +++ b/evo/core/filters.py @@ -158,3 +158,49 @@ def filter_pairs_by_angle(poses: typing.Sequence[np.ndarray], delta: float, accumulated_delta = 0.0 current_start_index = end_index return id_pairs + + +def filter_by_motion(poses: typing.Sequence[np.ndarray], + distance_threshold: float, angle_threshold: float, + degrees: bool = False): + """ + Filters a list of SE(3) poses by their motion if either the + distance or rotation angle is exceeded. + :param poses: list of SE(3) poses + :param distance_threshold: the distance threshold in meters + :param angle_threshold: the angle threshold in radians + (or degrees if degrees=True) + :param degrees: set to True if angle_threshold is in degrees + :return: list of indices of the filtered poses + """ + if len(poses) < 2: + raise FilterException("poses must contain at least two poses") + if distance_threshold < 0.0: + raise FilterException("distance threshold must be >= 0.0") + if angle_threshold < 0.0: + raise FilterException("angle threshold must be >= 0.0") + if degrees: + angle_threshold = np.deg2rad(angle_threshold) + + positions = np.array([pose[:3, 3] for pose in poses]) + distances = geometry.accumulated_distances(positions) + previous_angle_id = 0 + previous_distance = 0. + + filtered_ids = [0] + for i in range(1, len(poses)): + if distances[i] - previous_distance >= distance_threshold: + filtered_ids.append(i) + previous_angle_id = i + previous_distance = distances[i] + continue + current_angle = lie.so3_log_angle( + lie.relative_so3(poses[previous_angle_id][:3, :3], + poses[i][:3, :3])) + if current_angle >= angle_threshold: + filtered_ids.append(i) + previous_angle_id = i + previous_distance = distances[i] + continue + + return filtered_ids diff --git a/evo/core/trajectory.py b/evo/core/trajectory.py index 43672094..2d97064b 100644 --- a/evo/core/trajectory.py +++ b/evo/core/trajectory.py @@ -29,6 +29,7 @@ import evo.core.transformations as tr import evo.core.geometry as geometry from evo.core import lie_algebra as lie +from evo.core import filters logger = logging.getLogger(__name__) @@ -298,6 +299,35 @@ def reduce_to_ids( if hasattr(self, "_poses_se3"): self._poses_se3 = [self._poses_se3[idx] for idx in ids] + def downsample(self, num_poses: int) -> None: + """ + Downsample the trajectory to the specified number of poses + with a simple evenly spaced sampling. + Does nothing if the trajectory already has less or equal poses. + :param num_poses: number of poses to keep + """ + if self.num_poses <= num_poses: + return + if self.num_poses < 2 or num_poses < 2: + raise TrajectoryException("can't downsample to less than 2 poses") + ids = np.linspace(0, self.num_poses - 1, num_poses, dtype=int) + self.reduce_to_ids(ids) + + def motion_filter(self, distance_threshold: float, angle_threshold: float, + degrees: bool = False) -> None: + """ + Filters the trajectory by its motion if either the accumulated distance + or rotation angle is exceeded. + :param distance_threshold: the distance threshold in meters + :param angle_threshold: the angle threshold in radians + (or degrees if degrees=True) + :param degrees: set to True if angle_threshold is in degrees + """ + filtered_ids = filters.filter_by_motion(self.poses_se3, + distance_threshold, + angle_threshold, degrees) + self.reduce_to_ids(filtered_ids) + def check(self) -> typing.Tuple[bool, dict]: """ checks if the data is valid diff --git a/evo/main_traj.py b/evo/main_traj.py index 45d730f9..8acefc25 100755 --- a/evo/main_traj.py +++ b/evo/main_traj.py @@ -193,6 +193,27 @@ def run(args): trajectories, ref_traj = load_trajectories(args) + if args.downsample: + logger.debug(SEP) + logger.info("Downsampling trajectories to max %s poses.", + args.downsample) + for traj in trajectories.values(): + traj.downsample(args.downsample) + if ref_traj: + ref_traj.downsample(args.downsample) + + if args.motion_filter: + logger.debug(SEP) + distance_threshold = args.motion_filter[0] + angle_threshold = args.motion_filter[1] + logger.info( + "Filtering trajectories with motion filter " + "thresholds: %f m, %f deg", distance_threshold, angle_threshold) + for traj in trajectories.values(): + traj.motion_filter(distance_threshold, angle_threshold, True) + if ref_traj: + ref_traj.motion_filter(distance_threshold, angle_threshold, True) + if args.merge: if args.subcommand == "kitti": die("Can't merge KITTI files.") diff --git a/evo/main_traj_parser.py b/evo/main_traj_parser.py index cf66a26f..36f39e44 100644 --- a/evo/main_traj_parser.py +++ b/evo/main_traj_parser.py @@ -60,6 +60,14 @@ def parser() -> argparse.ArgumentParser: "--project_to_plane", type=str, choices=["xy", "xz", "yz"], help="Projects the trajectories to 2D in the desired plane. " "This is done after potential 3D alignment & transformation steps.") + algo_opts.add_argument("--downsample", type=int, + help="Downsample trajectories to max N poses.") + algo_opts.add_argument( + "--motion_filter", type=float, nargs=2, + metavar=("DISTANCE", "ANGLE_DEGREES"), + help="Filters out poses if the distance or angle to the previous one " + " is below the threshold distance or angle. " + "Angle is expected in degrees.") output_opts.add_argument("-p", "--plot", help="show plot window", action="store_true") output_opts.add_argument( diff --git a/test/test_filters.py b/test/test_filters.py index c6749253..73c6d900 100755 --- a/test/test_filters.py +++ b/test/test_filters.py @@ -166,5 +166,21 @@ def test_poses6_all_pairs(self): self.assertEqual(id_pairs, expected_result) +class TestFilterByMotion(unittest.TestCase): + def test_angle_threshold_only(self): + poses = POSES_5 + angle_threshold = math.pi + expected_result = [0, 1, 2, 4] + filtered_ids = filters.filter_by_motion(poses, 999, angle_threshold) + self.assertEqual(filtered_ids, expected_result) + + def test_distance_threshold_only(self): + poses = POSES_2 + distance_threshold = 0.5 + expected_result = [0, 1, 3] + filtered_ids = filters.filter_by_motion(poses, distance_threshold, 99) + self.assertEqual(filtered_ids, expected_result) + + if __name__ == '__main__': unittest.main(verbosity=2) diff --git a/test/test_trajectory.py b/test/test_trajectory.py index 72ca5afd..d023603b 100755 --- a/test/test_trajectory.py +++ b/test/test_trajectory.py @@ -99,6 +99,18 @@ def test_reduce_to_ids(self): len_reduced = path_reduced.path_length self.assertAlmostEqual(len_initial_segment, len_reduced) + def test_downsample(self): + path = helpers.fake_path(100) + path_downsampled = copy.deepcopy(path) + path_downsampled.downsample(10) + self.assertEqual(path_downsampled.num_poses, 10) + self.assertTrue( + np.equal(path.positions_xyz[0], + path_downsampled.positions_xyz[0]).all()) + self.assertTrue( + np.equal(path.positions_xyz[-1], + path_downsampled.positions_xyz[-1]).all()) + def test_transform(self): path = helpers.fake_path(10) path_transformed = copy.deepcopy(path)