From be583bcc247c3737153011e6fa31cd311a1e71cb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kaan=20=C3=87olak?= Date: Tue, 10 Oct 2023 17:10:21 +0300 Subject: [PATCH] pp refactor and add ONNX converter MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Kaan Çolak --- configs/centerpoint/centerpoint_custom.py | 181 ++++++++++++ .../models/voxel_encoders/pillar_encoder.py | 64 ++-- tools/centerpoint_onnx_converter.py | 273 ++++++++++++++++++ 3 files changed, 490 insertions(+), 28 deletions(-) create mode 100755 configs/centerpoint/centerpoint_custom.py create mode 100644 tools/centerpoint_onnx_converter.py diff --git a/configs/centerpoint/centerpoint_custom.py b/configs/centerpoint/centerpoint_custom.py new file mode 100755 index 0000000000..7095ba3ccd --- /dev/null +++ b/configs/centerpoint/centerpoint_custom.py @@ -0,0 +1,181 @@ +_base_ = [ + '../_base_/datasets/nus-3d.py', + '../_base_/models/centerpoint_pillar02_second_secfpn_nus.py', + '../_base_/schedules/cyclic-20e.py', '../_base_/default_runtime.py' +] + +point_cloud_range = [-51.2, -51.2, -5.0, 51.2, 51.2, 3.0] + +class_names = [ + "car", + "truck", + "bus", + "bicycle", + "pedestrian", +] + +data_prefix = dict(pts='samples/LIDAR_TOP', img='', sweeps='sweeps/LIDAR_TOP') + +out_size_factor = 1 +model = dict( + data_preprocessor=dict( + voxel_layer=dict( + point_cloud_range=point_cloud_range)), + pts_voxel_encoder=dict( + point_cloud_range=point_cloud_range, + in_channels=4, + feat_channels=[32, 32], + use_voxel_center_z=False), + pts_middle_encoder=dict( + in_channels=32), + pts_backbone=dict( + in_channels=32, + layer_strides=[1, 2, 2],), + pts_neck=dict( + upsample_strides=[1, 2, 4], ), + pts_bbox_head=dict( + tasks=[dict( + num_class=len(class_names), + class_names=class_names)], + bbox_coder=dict( + out_size_factor=out_size_factor, + pc_range=point_cloud_range[:2])), + # model training and testing settings + train_cfg=dict( + pts=dict( + 
point_cloud_range=point_cloud_range, + out_size_factor=out_size_factor)), + test_cfg=dict( + pts=dict( + pc_range=point_cloud_range[:2], + nms_type='circle', + out_size_factor=out_size_factor,))) + +dataset_type = 'NuScenesDataset' +data_root = 'data/nuscenes/' +backend_args = None + +point_load_dim = 5 +point_use_dim = [0, 1, 2, 4] + +db_sampler = dict( + data_root=data_root, + info_path=data_root + 'nuscenes_dbinfos_train.pkl', + rate=1.0, + prepare=dict( + filter_by_difficulty=[-1], + filter_by_min_points=dict( + car=5, + truck=5, + bus=5, + bicycle=5, + pedestrian=5)), + classes=class_names, + sample_groups=dict( + car=2, + truck=3, + bus=4, + bicycle=6, + pedestrian=2), + points_loader=dict( + type='LoadPointsFromFile', + coord_type='LIDAR', + load_dim=point_load_dim, + use_dim=point_use_dim, + backend_args=backend_args), + backend_args=backend_args) + +train_pipeline = [ + dict( + type='LoadPointsFromFile', + coord_type='LIDAR', + load_dim=point_load_dim, + use_dim=5, + backend_args=backend_args), + dict( + type='LoadPointsFromMultiSweeps', + sweeps_num=9, + use_dim=point_use_dim, + pad_empty_sweeps=True, + remove_close=True, + backend_args=backend_args), + dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True), + dict(type='ObjectSample', db_sampler=db_sampler), + dict( + type='GlobalRotScaleTrans', + rot_range=[-0.3925, 0.3925], + scale_ratio_range=[0.95, 1.05], + translation_std=[0, 0, 0]), + dict( + type='RandomFlip3D', + sync_2d=False, + flip_ratio_bev_horizontal=0.5, + flip_ratio_bev_vertical=0.5), + dict(type='PointsRangeFilter', point_cloud_range=point_cloud_range), + dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range), + dict(type='ObjectNameFilter', classes=class_names), + dict(type='PointShuffle'), + dict( + type='Pack3DDetInputs', + keys=['points', 'gt_bboxes_3d', 'gt_labels_3d']) +] +test_pipeline = [ + dict( + type='LoadPointsFromFile', + coord_type='LIDAR', + load_dim=point_load_dim, + use_dim=5, + 
backend_args=backend_args), + dict( + type='LoadPointsFromMultiSweeps', + sweeps_num=9, + use_dim=point_use_dim, + pad_empty_sweeps=True, + remove_close=True, + backend_args=backend_args), + dict( + type='MultiScaleFlipAug3D', + img_scale=(1333, 800), + pts_scale_ratio=1, + flip=False, + transforms=[ + dict( + type='GlobalRotScaleTrans', + rot_range=[0, 0], + scale_ratio_range=[1., 1.], + translation_std=[0, 0, 0]), + dict(type='RandomFlip3D') + ]), + dict(type='Pack3DDetInputs', keys=['points']) +] + +train_dataloader = dict( + _delete_=True, + batch_size=2, + num_workers=4, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + dataset=dict( + type='CBGSDataset', + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file='nuscenes_infos_train.pkl', + pipeline=train_pipeline, + metainfo=dict(classes=class_names), + test_mode=False, + data_prefix=data_prefix, + use_valid_flag=True, + box_type_3d='LiDAR', + backend_args=backend_args))) +test_dataloader = dict( + dataset=dict(pipeline=test_pipeline, metainfo=dict(classes=class_names))) +val_dataloader = dict( + dataset=dict(pipeline=test_pipeline, metainfo=dict(classes=class_names))) + +train_cfg = dict(val_interval=5) +default_hooks = dict( + checkpoint=dict( + type='CheckpointHook', + interval=1, + save_optimizer=True)) diff --git a/mmdet3d/models/voxel_encoders/pillar_encoder.py b/mmdet3d/models/voxel_encoders/pillar_encoder.py index bf0189cfb2..a3d0ea2e69 100644 --- a/mmdet3d/models/voxel_encoders/pillar_encoder.py +++ b/mmdet3d/models/voxel_encoders/pillar_encoder.py @@ -40,7 +40,7 @@ class PillarFeatureNet(nn.Module): def __init__(self, in_channels: Optional[int] = 4, - feat_channels: Optional[tuple] = (64, ), + feat_channels: Optional[tuple] = (64,), with_distance: Optional[bool] = False, with_cluster_center: Optional[bool] = True, with_voxel_center: Optional[bool] = True, @@ -50,14 +50,18 @@ def __init__(self, norm_cfg: Optional[dict] = dict( type='BN1d', eps=1e-3, 
momentum=0.01), mode: Optional[str] = 'max', - legacy: Optional[bool] = True): + legacy: Optional[bool] = True, + use_voxel_center_z: Optional[bool] = True, ): super(PillarFeatureNet, self).__init__() assert len(feat_channels) > 0 self.legacy = legacy + self.use_voxel_center_z = use_voxel_center_z if with_cluster_center: in_channels += 3 if with_voxel_center: - in_channels += 3 + in_channels += 2 + if self.use_voxel_center_z: + in_channels += 1 if with_distance: in_channels += 1 self._with_distance = with_distance @@ -110,35 +114,38 @@ def forward(self, features: Tensor, num_points: Tensor, coors: Tensor, if self._with_cluster_center: points_mean = features[:, :, :3].sum( dim=1, keepdim=True) / num_points.type_as(features).view( - -1, 1, 1) + -1, 1, 1) f_cluster = features[:, :, :3] - points_mean features_ls.append(f_cluster) # Find distance of x, y, and z from pillar center dtype = features.dtype if self._with_voxel_center: + center_feature_size = 3 if self.use_voxel_center_z else 2 if not self.legacy: - f_center = torch.zeros_like(features[:, :, :3]) + f_center = torch.zeros_like(features[:, :, :center_feature_size]) f_center[:, :, 0] = features[:, :, 0] - ( - coors[:, 3].to(dtype).unsqueeze(1) * self.vx + - self.x_offset) + coors[:, 3].to(dtype).unsqueeze(1) * self.vx + + self.x_offset) f_center[:, :, 1] = features[:, :, 1] - ( - coors[:, 2].to(dtype).unsqueeze(1) * self.vy + - self.y_offset) - f_center[:, :, 2] = features[:, :, 2] - ( - coors[:, 1].to(dtype).unsqueeze(1) * self.vz + - self.z_offset) + coors[:, 2].to(dtype).unsqueeze(1) * self.vy + + self.y_offset) + if self.use_voxel_center_z: + f_center[:, :, 2] = features[:, :, 2] - ( + coors[:, 1].to(dtype).unsqueeze(1) * self.vz + + self.z_offset) else: - f_center = features[:, :, :3] + f_center = features[:, :, :center_feature_size] f_center[:, :, 0] = f_center[:, :, 0] - ( - coors[:, 3].type_as(features).unsqueeze(1) * self.vx + - self.x_offset) + coors[:, 3].type_as(features).unsqueeze(1) * self.vx + + 
self.x_offset) f_center[:, :, 1] = f_center[:, :, 1] - ( - coors[:, 2].type_as(features).unsqueeze(1) * self.vy + - self.y_offset) - f_center[:, :, 2] = f_center[:, :, 2] - ( - coors[:, 1].type_as(features).unsqueeze(1) * self.vz + - self.z_offset) + coors[:, 2].type_as(features).unsqueeze(1) * self.vy + + self.y_offset) + if self.use_voxel_center_z: + f_center[:, :, 2] = f_center[:, :, 2] - ( + coors[:, 1].type_as(features).unsqueeze(1) * self.vz + + self.z_offset) features_ls.append(f_center) if self._with_distance: @@ -193,7 +200,7 @@ class DynamicPillarFeatureNet(PillarFeatureNet): def __init__(self, in_channels: Optional[int] = 4, - feat_channels: Optional[tuple] = (64, ), + feat_channels: Optional[tuple] = (64,), with_distance: Optional[bool] = False, with_cluster_center: Optional[bool] = True, with_voxel_center: Optional[bool] = True, @@ -264,15 +271,15 @@ def map_voxel_center_to_point(self, pts_coors: Tensor, voxel_mean: Tensor, canvas = voxel_mean.new_zeros(canvas_channel, canvas_len) # Only include non-empty pillars indices = ( - voxel_coors[:, 0] * canvas_y * canvas_x + - voxel_coors[:, 2] * canvas_x + voxel_coors[:, 3]) + voxel_coors[:, 0] * canvas_y * canvas_x + + voxel_coors[:, 2] * canvas_x + voxel_coors[:, 3]) # Scatter the blob back to the canvas canvas[:, indices.long()] = voxel_mean.t() # Step 2: get voxel mean for each point voxel_index = ( - pts_coors[:, 0] * canvas_y * canvas_x + - pts_coors[:, 2] * canvas_x + pts_coors[:, 3]) + pts_coors[:, 0] * canvas_y * canvas_x + + pts_coors[:, 2] * canvas_x + pts_coors[:, 3]) center_per_point = canvas[:, voxel_index.long()].t() return center_per_point @@ -301,11 +308,11 @@ def forward(self, features: Tensor, coors: Tensor) -> Tensor: if self._with_voxel_center: f_center = features.new_zeros(size=(features.size(0), 3)) f_center[:, 0] = features[:, 0] - ( - coors[:, 3].type_as(features) * self.vx + self.x_offset) + coors[:, 3].type_as(features) * self.vx + self.x_offset) f_center[:, 1] = features[:, 1] - ( 
- coors[:, 2].type_as(features) * self.vy + self.y_offset) + coors[:, 2].type_as(features) * self.vy + self.y_offset) f_center[:, 2] = features[:, 2] - ( - coors[:, 1].type_as(features) * self.vz + self.z_offset) + coors[:, 1].type_as(features) * self.vz + self.z_offset) features_ls.append(f_center) if self._with_distance: @@ -324,3 +331,4 @@ def forward(self, features: Tensor, coors: Tensor) -> Tensor: features = torch.cat([point_feats, feat_per_point], dim=1) return voxel_feats, voxel_coors + diff --git a/tools/centerpoint_onnx_converter.py b/tools/centerpoint_onnx_converter.py new file mode 100644 index 0000000000..0d59d4d131 --- /dev/null +++ b/tools/centerpoint_onnx_converter.py @@ -0,0 +1,273 @@ +import argparse +import os +import torch +from typing import Any, Dict, Optional, Tuple, List +from mmengine import Config +from mmengine.registry import MODELS, Registry +from mmengine.runner import Runner +from mmdet3d.apis import init_model +from mmdet3d.models.dense_heads.centerpoint_head import SeparateHead, CenterHead +from mmdet3d.models.voxel_encoders.pillar_encoder import PillarFeatureNet +from mmdet3d.models.voxel_encoders.utils import get_paddings_indicator + + +def parse_args(): + parser = argparse.ArgumentParser(description='Create autoware compatible onnx file from torch checkpoint ') + parser.add_argument('--cfg', help='train config file path') + parser.add_argument('--ckpt', help='checkpoint weight') + parser.add_argument('--work-dir', help='the dir to save onnx files') + + args = parser.parse_args() + return args + + +class CenterPointToONNX(object): + def __init__( + self, + config: Config, + checkpoint_path: Optional[str] = None, + output_path: Optional[str] = None, + ): + assert isinstance(config, Config), f"expected `mmcv.Config`, but got {type(config)}" + _, ext = os.path.splitext(checkpoint_path) + assert ext == ".pth", f"expected .pth model, but got {ext}" + + self.config = config + self.checkpoint_path = checkpoint_path + 
os.makedirs(output_path, exist_ok=True) + self.output_path = output_path + + def save_onnx(self) -> None: + # Overwrite models with Autoware's TensorRT compatible versions + self.config.model.pts_voxel_encoder.type = "PillarFeatureNetONNX" + self.config.model.pts_bbox_head.type = "CenterHeadONNX" + self.config.model.pts_bbox_head.separate_head.type = "SeparateHeadONNX" + + model = init_model(self.config, self.checkpoint_path, device="cuda:0") + dataloader = Runner.build_dataloader(self.config.test_dataloader) + batch_dict = next(iter(dataloader)) + + voxel_dict = model.data_preprocessor.voxelize(batch_dict["inputs"]["points"], batch_dict) + input_features = model.pts_voxel_encoder.get_input_features(voxel_dict["voxels"], voxel_dict["num_points"], + voxel_dict["coors"]).to("cuda:0") + + # CenterPoint's PointPillar voxel encoder ONNX conversion + pth_onnx_pve = os.path.join(self.output_path, "pts_voxel_encoder_centerpoint_custom.onnx") + torch.onnx.export( + model.pts_voxel_encoder, + (input_features,), + f=pth_onnx_pve, + input_names=("input_features",), + output_names=("pillar_features",), + dynamic_axes={ + "input_features": {0: "num_voxels", 1: "num_max_points"}, + "pillar_features": {0: "num_voxels"}, + }, + verbose=False, + opset_version=11, + ) + print(f"Saved pts_voxel_encoder onnx model: {pth_onnx_pve}") + + voxel_features = model.pts_voxel_encoder(input_features) + voxel_features = voxel_features.squeeze() + + batch_size = voxel_dict["coors"][-1, 0] + 1 + x = model.pts_middle_encoder(voxel_features, voxel_dict["coors"], batch_size) + + # CenterPoint backbone's to neck ONNX conversion + pts_backbone_neck_head = CenterPointHeadONNX( + model.pts_backbone, + model.pts_neck, + model.pts_bbox_head, + ) + + pth_onnx_backbone_neck_head = os.path.join(self.output_path, "pts_backbone_neck_head_centerpoint_custom.onnx") + torch.onnx.export( + pts_backbone_neck_head, + (x,), + f=pth_onnx_backbone_neck_head, + input_names=("spatial_features",), + 
output_names=tuple(model.pts_bbox_head.output_names), + dynamic_axes={ + name: {0: "batch_size", 2: "H", 3: "W"} + for name in ["spatial_features"] + model.pts_bbox_head.output_names + }, + verbose=False, + opset_version=11, + ) + print(f"Saved pts_backbone_neck_head onnx model: {pth_onnx_backbone_neck_head}") + + +@MODELS.register_module() +class PillarFeatureNetONNX(PillarFeatureNet): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def get_input_features( + self, features: torch.Tensor, num_points: torch.Tensor, coors: torch.Tensor + ) -> torch.Tensor: + """Forward function. + Args: + features (torch.Tensor): Point features or raw points in shape + (N, M, C). + num_points (torch.Tensor): Number of points in each pillar. + coors (torch.Tensor): Coordinates of each voxel. + Returns: + torch.Tensor: Features of pillars. + """ + + features_ls = [features] + # Find distance of x, y, and z from cluster center + if self._with_cluster_center: + points_mean = features[:, :, :3].sum( + dim=1, keepdim=True) / num_points.type_as(features).view( + -1, 1, 1) + f_cluster = features[:, :, :3] - points_mean + features_ls.append(f_cluster) + + # Find distance of x, y, and z from pillar center + dtype = features.dtype + if self._with_voxel_center: + center_feature_size = 3 if self.use_voxel_center_z else 2 + if not self.legacy: + f_center = torch.zeros_like(features[:, :, :center_feature_size]) + f_center[:, :, 0] = features[:, :, 0] - ( + coors[:, 3].to(dtype).unsqueeze(1) * self.vx + + self.x_offset) + f_center[:, :, 1] = features[:, :, 1] - ( + coors[:, 2].to(dtype).unsqueeze(1) * self.vy + + self.y_offset) + if self.use_voxel_center_z: + f_center[:, :, 2] = features[:, :, 2] - ( + coors[:, 1].to(dtype).unsqueeze(1) * self.vz + + self.z_offset) + else: + f_center = features[:, :, :center_feature_size] + f_center[:, :, 0] = f_center[:, :, 0] - ( + coors[:, 3].type_as(features).unsqueeze(1) * self.vx + + self.x_offset) + f_center[:, :, 1] = f_center[:, :, 1] - ( + 
coors[:, 2].type_as(features).unsqueeze(1) * self.vy + + self.y_offset) + if self.use_voxel_center_z: + f_center[:, :, 2] = f_center[:, :, 2] - ( + coors[:, 1].type_as(features).unsqueeze(1) * self.vz + + self.z_offset) + features_ls.append(f_center) + + if self._with_distance: + points_dist = torch.norm(features[:, :, :3], 2, 2, keepdim=True) + features_ls.append(points_dist) + + features = torch.cat(features_ls, dim=-1) + + voxel_count = features.shape[1] + mask = get_paddings_indicator(num_points, voxel_count, axis=0) + mask = torch.unsqueeze(mask, -1).type_as(features) + features *= mask + + return features + + def forward( + self, + features: torch.Tensor, + ) -> torch.Tensor: + """Forward function. + Args: + features (torch.Tensor): Point features in shape (N, M, C). + num_points (torch.Tensor): Number of points in each pillar. + coors (torch.Tensor): + Returns: + torch.Tensor: Features of pillars. + """ + + for pfn in self.pfn_layers: + features = pfn(features) + + return features + + +@MODELS.register_module() +class SeparateHeadONNX(SeparateHead): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + # Order output's of the heads + rot_heads = {k: None for k in self.heads.keys() if "rot" in k} + self.heads: Dict[str, None] = { + "heatmap": None, + "reg": None, + "height": None, + "dim": None, + **rot_heads, + "vel": None, + } + + +@MODELS.register_module() +class CenterHeadONNX(CenterHead): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + self.task_heads: List[SeparateHeadONNX] + self.output_names: List[str] = list(self.task_heads[0].heads.keys()) + + def forward(self, x: List[torch.Tensor]) -> Tuple[torch.Tensor]: + """ + Args: + x (List[torch.Tensor]): multi-level features + Returns: + pred (Tuple[torch.Tensor]): Output results for tasks. 
+ """ + assert len(x) == 1, "The input of CenterHeadONNX must be a single-level feature" + + x = self.shared_conv(x[0]) + head_tensors: Dict[str, torch.Tensor] = self.task_heads[0](x) + + ret_list: List[torch.Tensor] = list() + for head_name in self.output_names: + ret_list.append(head_tensors[head_name]) + + return tuple(ret_list) + + +class CenterPointHeadONNX(torch.nn.Module): + + def __init__(self, backbone: torch.nn.Module, neck: torch.nn.Module, bbox_head: torch.nn.Module): + super(CenterPointHeadONNX, self).__init__() + self.backbone: torch.nn.Module = backbone + self.neck: torch.nn.Module = neck + self.bbox_head: torch.nn.Module = bbox_head + + def forward(self, x: torch.Tensor) -> Tuple[List[Dict[str, torch.Tensor]]]: + """ + Args: + x (torch.Tensor): (B, C, H, W) + Returns: + tuple[list[dict[str, any]]]: + (num_classes x [num_detect x {'reg', 'height', 'dim', 'rot', 'vel', 'heatmap'}]) + """ + x = self.backbone(x) + if self.neck is not None: + x = self.neck(x) + x = self.bbox_head(x) + return x + + +CUSTOM_MODEL_REGISTRY = Registry('model', parent=MODELS) +CUSTOM_MODEL_REGISTRY.register_module(module=PillarFeatureNetONNX, force=True) +CUSTOM_MODEL_REGISTRY.register_module(module=CenterHeadONNX, force=True) +CUSTOM_MODEL_REGISTRY.register_module(module=SeparateHeadONNX, force=True) + + +def main(): + args = parse_args() + + cfg = Config.fromfile(args.cfg) + det = CenterPointToONNX(config=cfg, checkpoint_path=args.ckpt, output_path=args.work_dir) + det.save_onnx() + + +if __name__ == '__main__': + main()