Skip to content

Commit

Permalink
add TIERIV dataset reader and custom dataset nuscenes eval metric
Browse files Browse the repository at this point in the history
Signed-off-by: Kaan Çolak <[email protected]>
  • Loading branch information
kaancolak committed Nov 30, 2023
1 parent be583bc commit 5bd5365
Show file tree
Hide file tree
Showing 12 changed files with 2,938 additions and 2 deletions.
602 changes: 602 additions & 0 deletions configs/centerpoint/centerpoint_custom_test.py

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion mmdet3d/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@
RandomShiftScale, Resize3D, VoxelBasedPointSampler)
from .utils import get_loading_pipeline
from .waymo_dataset import WaymoDataset
from .tier4_dataset import Tier4Dataset

__all__ = [
'KittiDataset', 'CBGSDataset', 'NuScenesDataset', 'LyftDataset',
'KittiDataset', 'CBGSDataset', 'NuScenesDataset', 'LyftDataset', 'Tier4Dataset',
'ObjectSample', 'RandomFlip3D', 'ObjectNoise', 'GlobalRotScaleTrans',
'PointShuffle', 'ObjectRangeFilter', 'PointsRangeFilter',
'LoadPointsFromFile', 'S3DISSegDataset', 'S3DISDataset',
Expand Down
115 changes: 115 additions & 0 deletions mmdet3d/datasets/tier4_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
# Copyright (c) OpenMMLab. All rights reserved.
from os import path as osp
import os
from typing import Callable, List, Union

import numpy as np

from mmdet3d.registry import DATASETS
from mmdet3d.structures import LiDARInstance3DBoxes
from mmdet3d.structures.bbox_3d.cam_box3d import CameraInstance3DBoxes
# from .det3d_dataset import Det3DDataset
from .nuscenes_dataset import NuScenesDataset


@DATASETS.register_module()
class Tier4Dataset(NuScenesDataset):
METAINFO = {
'classes': ('car', 'truck', 'bus', 'bicycle', 'pedestrian'),
'version': 'v1.0-trainval',
'palette': [
(255, 158, 0), # Orange
(255, 99, 71), # Tomato
(255, 140, 0), # Darkorange
(255, 127, 80), # Coral
(233, 150, 70), # Darksalmon
]
}

def __init__(self,
box_type_3d: str = 'LiDAR',
load_type: str = 'frame_based',
with_velocity: bool = True,
use_valid_flag: bool = False,
**kwargs,) -> None:

self.use_valid_flag = use_valid_flag
self.with_velocity = with_velocity

# TODO: Redesign multi-view data process in the future
assert load_type in ('frame_based', 'mv_image_based',
'fov_image_based')
self.load_type = load_type

assert box_type_3d.lower() in ('lidar', 'camera')
super().__init__(**kwargs)

def parse_data_info(self, info: dict) -> dict:
"""Process the raw data info.
Convert all relative path of needed modality data file to
the absolute path. And process the `instances` field to
`ann_info` in training stage.
Args:
info (dict): Raw info dict.
Returns:
dict: Has `ann_info` in training stage. And
all path has been converted to absolute path.
"""
if self.load_type == 'mv_image_based':
info = super().parse_data_info(info)
else:
if self.modality['use_lidar']:
info['lidar_points']['lidar_path'] = \
osp.join(
self.data_prefix.get('pts', ''),
info['lidar_points']['lidar_path'])

info['num_pts_feats'] = info['lidar_points']['num_pts_feats']
info['lidar_path'] = info['lidar_points']['lidar_path']
if 'lidar_sweeps' in info:
for sweep in info['lidar_sweeps']:
file_suffix_splitted = sweep['lidar_points']['lidar_path'].split(os.sep)
file_suffix = os.sep.join(file_suffix_splitted[-4:])
if 'samples' in sweep['lidar_points']['lidar_path']:
sweep['lidar_points']['lidar_path'] = osp.join(
self.data_prefix['pts'], file_suffix)
else:
sweep['lidar_points']['lidar_path'] = info['lidar_points']['lidar_path']

if self.modality['use_camera']:
for cam_id, img_info in info['images'].items():
if 'img_path' in img_info:
if cam_id in self.data_prefix:
cam_prefix = self.data_prefix[cam_id]
else:
cam_prefix = self.data_prefix.get('img', '')
img_info['img_path'] = osp.join(cam_prefix,
img_info['img_path'])
if self.default_cam_key is not None:
info['img_path'] = info['images'][
self.default_cam_key]['img_path']
if 'lidar2cam' in info['images'][self.default_cam_key]:
info['lidar2cam'] = np.array(
info['images'][self.default_cam_key]['lidar2cam'])
if 'cam2img' in info['images'][self.default_cam_key]:
info['cam2img'] = np.array(
info['images'][self.default_cam_key]['cam2img'])
if 'lidar2img' in info['images'][self.default_cam_key]:
info['lidar2img'] = np.array(
info['images'][self.default_cam_key]['lidar2img'])
else:
info['lidar2img'] = info['cam2img'] @ info['lidar2cam']

if not self.test_mode:
# used in training
info['ann_info'] = self.parse_ann_info(info)
if self.test_mode and self.load_eval_anns:
info['eval_ann_info'] = self.parse_ann_info(info)

return info



16 changes: 16 additions & 0 deletions mmdet3d/evaluation/functional/nuscenes_utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from .eval import DetectionConfig, nuScenesDetectionEval
from .utils import (
class_mapping_kitti2nuscenes,
format_nuscenes_metrics,
format_nuscenes_metrics_table,
transform_det_annos_to_nusc_annos,
)

__all__ = [
"DetectionConfig",
"nuScenesDetectionEval",
"class_mapping_kitti2nuscenes",
"format_nuscenes_metrics_table",
"format_nuscenes_metrics",
"transform_det_annos_to_nusc_annos",
]
173 changes: 173 additions & 0 deletions mmdet3d/evaluation/functional/nuscenes_utils/eval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
import json
import os
from typing import Dict, List, Optional, Tuple

import numpy as np
from nuscenes.eval.common.data_classes import EvalBox, EvalBoxes
from nuscenes.eval.common.loaders import load_prediction
from nuscenes.eval.common.utils import center_distance
from nuscenes.eval.detection.data_classes import DetectionBox
from nuscenes.eval.detection.evaluate import DetectionEval as _DetectionEval


def toEvalBoxes(nusc_boxes: Dict[str, List[Dict]], box_cls: EvalBox = DetectionBox) -> EvalBoxes:
"""
nusc_boxes = {
"sample_token_1": [
{
"sample_token": str,
"translation": List[float], (x, y, z)
"size": List[float], (width, length, height)
"rotation": List[float], (w, x, y, z)
"velocity": List[float], (vx, vy)
"detection_name": str,
"detection_score": float,
"attribute_name": str,
},
...
],
...
}
Args:
nusc_boxes (Dict[List[Dict]]): [description]
box_cls (EvalBox, optional): [description]. Defaults to DetectionBox.
Returns:
EvalBoxes: [description]
"""
return EvalBoxes.deserialize(nusc_boxes, box_cls)


class DetectionConfig:
"""Data class that specifies the detection evaluation settings."""

def __init__(
self,
class_names: List[str],
class_range: Dict[str, int],
dist_fcn: str,
dist_ths: List[float],
dist_th_tp: float,
min_recall: float,
min_precision: float,
max_boxes_per_sample: float,
mean_ap_weight: int,
):

# assert set(class_range.keys()) == set(DETECTION_NAMES), "Class count mismatch."
assert dist_th_tp in dist_ths, "dist_th_tp must be in set of dist_ths."

self.class_range = class_range
self.dist_fcn = dist_fcn
self.dist_ths = dist_ths
self.dist_th_tp = dist_th_tp
self.min_recall = min_recall
self.min_precision = min_precision
self.max_boxes_per_sample = max_boxes_per_sample
self.mean_ap_weight = mean_ap_weight

self.class_names = class_names

def __eq__(self, other):
eq = True
for key in self.serialize().keys():
eq = eq and np.array_equal(getattr(self, key), getattr(other, key))
return eq

def serialize(self) -> dict:
"""Serialize instance into json-friendly format."""
return {
"class_names": self.class_names,
"class_range": self.class_range,
"dist_fcn": self.dist_fcn,
"dist_ths": self.dist_ths,
"dist_th_tp": self.dist_th_tp,
"min_recall": self.min_recall,
"min_precision": self.min_precision,
"max_boxes_per_sample": self.max_boxes_per_sample,
"mean_ap_weight": self.mean_ap_weight,
}

@classmethod
def deserialize(cls, content: dict):
"""Initialize from serialized dictionary."""
return cls(
content["class_names"],
content["class_range"],
content["dist_fcn"],
content["dist_ths"],
content["dist_th_tp"],
content["min_recall"],
content["min_precision"],
content["max_boxes_per_sample"],
content["mean_ap_weight"],
)

@property
def dist_fcn_callable(self):
"""Return the distance function corresponding to the dist_fcn string."""
if self.dist_fcn == "center_distance":
return center_distance
else:
raise Exception("Error: Unknown distance function %s!" % self.dist_fcn)


class nuScenesDetectionEval(_DetectionEval):
"""
This is the official nuScenes detection evaluation code.
Results are written to the provided output_dir.
nuScenes uses the following detection metrics:
- Mean Average Precision (mAP): Uses center-distance as matching criterion; averaged over distance thresholds.
- True Positive (TP) metrics: Average of translation, velocity, scale, orientation and attribute errors.
- nuScenes Detection Score (NDS): The weighted sum of the above.
Here is an overview of the functions in this method:
- init: Loads GT annotations and predictions stored in JSON format and filters the boxes.
- run: Performs evaluation and dumps the metric data to disk.
- render: Renders various plots and dumps to disk.
We assume that:
- Every sample_token is given in the results, although there may be not predictions for that sample.
Please see https://www.nuscenes.org/object-detection for more details.
"""

def __init__(
self,
config: DetectionConfig,
result_boxes: Dict,
gt_boxes: Dict,
meta: Dict,
eval_set: str,
output_dir: Optional[str] = None,
verbose: bool = True,
):
"""
Initialize a DetectionEval object.
:param config: A DetectionConfig object.
:param result_boxes: result bounding boxes.
:param gt_boxes: ground-truth bounding boxes.
:param eval_set: The dataset split to evaluate on, e.g. train, val or test.
:param output_dir: Folder to save plots and results to.
:param verbose: Whether to print to stdout.
"""
self.cfg = config
self.meta = meta
self.eval_set = eval_set
self.output_dir = output_dir
self.verbose = verbose

# Make dirs.
self.plot_dir = os.path.join(self.output_dir, "plots")
if not os.path.isdir(self.output_dir):
os.makedirs(self.output_dir)
if not os.path.isdir(self.plot_dir):
os.makedirs(self.plot_dir)

self.pred_boxes: EvalBoxes = toEvalBoxes(result_boxes)
self.gt_boxes: EvalBoxes = toEvalBoxes(gt_boxes)

assert set(self.pred_boxes.sample_tokens) == set(
self.gt_boxes.sample_tokens
), "Samples in split doesn't match samples in predictions."

self.sample_tokens = self.gt_boxes.sample_tokens
Loading

0 comments on commit 5bd5365

Please sign in to comment.