Skip to content

Commit

Permalink
feat: use project-style
Browse files Browse the repository at this point in the history
Signed-off-by: Kaan Çolak <[email protected]>
  • Loading branch information
kaancolak committed Jan 22, 2024
1 parent 5c0613b commit ff616ff
Show file tree
Hide file tree
Showing 18 changed files with 4,243 additions and 0 deletions.
Empty file.
3 changes: 3 additions & 0 deletions projects/AutowareCenterPoint/centerpoint/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .pillar_encoder_autoware import PillarFeatureNetAutoware

__all__ = [ 'PillarFeatureNetAutoware']
165 changes: 165 additions & 0 deletions projects/AutowareCenterPoint/centerpoint/pillar_encoder_autoware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
from typing import Optional, Tuple

import torch
from mmcv.cnn import build_norm_layer
from mmcv.ops import DynamicScatter
from torch import Tensor, nn

from mmdet3d.registry import MODELS
from mmdet3d.models.voxel_encoders.utils import PFNLayer, get_paddings_indicator
@MODELS.register_module()
class PillarFeatureNetAutoware(nn.Module):
"""Pillar Feature Net.
The network prepares the pillar features and performs forward pass
through PFNLayers.
Args:
in_channels (int, optional): Number of input features,
either x, y, z or x, y, z, r. Defaults to 4.
feat_channels (tuple, optional): Number of features in each of the
N PFNLayers. Defaults to (64, ).
with_distance (bool, optional): Whether to include Euclidean distance
to points. Defaults to False.
with_cluster_center (bool, optional): [description]. Defaults to True.
with_voxel_center (bool, optional): [description]. Defaults to True.
voxel_size (tuple[float], optional): Size of voxels, only utilize x
and y size. Defaults to (0.2, 0.2, 4).
point_cloud_range (tuple[float], optional): Point cloud range, only
utilizes x and y min. Defaults to (0, -40, -3, 70.4, 40, 1).
norm_cfg ([type], optional): [description].
Defaults to dict(type='BN1d', eps=1e-3, momentum=0.01).
mode (str, optional): The mode to gather point features. Options are
'max' or 'avg'. Defaults to 'max'.
legacy (bool, optional): Whether to use the new behavior or
the original behavior. Defaults to True.
"""

def __init__(self,
in_channels: Optional[int] = 4,
feat_channels: Optional[tuple] = (64,),
with_distance: Optional[bool] = False,
with_cluster_center: Optional[bool] = True,
with_voxel_center: Optional[bool] = True,
voxel_size: Optional[Tuple[float]] = (0.2, 0.2, 4),
point_cloud_range: Optional[Tuple[float]] = (0, -40, -3, 70.4,
40, 1),
norm_cfg: Optional[dict] = dict(
type='BN1d', eps=1e-3, momentum=0.01),
mode: Optional[str] = 'max',
legacy: Optional[bool] = True,
use_voxel_center_z: Optional[bool] = True, ):
super(PillarFeatureNetAutoware, self).__init__()
assert len(feat_channels) > 0
self.legacy = legacy
self.use_voxel_center_z = use_voxel_center_z
if with_cluster_center:
in_channels += 3
if with_voxel_center:
in_channels += 2
if self.use_voxel_center_z:
in_channels += 1
if with_distance:
in_channels += 1
self._with_distance = with_distance
self._with_cluster_center = with_cluster_center
self._with_voxel_center = with_voxel_center
# Create PillarFeatureNet layers
self.in_channels = in_channels
feat_channels = [in_channels] + list(feat_channels)
pfn_layers = []
for i in range(len(feat_channels) - 1):
in_filters = feat_channels[i]
out_filters = feat_channels[i + 1]
if i < len(feat_channels) - 2:
last_layer = False
else:
last_layer = True
pfn_layers.append(
PFNLayer(
in_filters,
out_filters,
norm_cfg=norm_cfg,
last_layer=last_layer,
mode=mode))
self.pfn_layers = nn.ModuleList(pfn_layers)

# Need pillar (voxel) size and x/y offset in order to calculate offset
self.vx = voxel_size[0]
self.vy = voxel_size[1]
self.vz = voxel_size[2]
self.x_offset = self.vx / 2 + point_cloud_range[0]
self.y_offset = self.vy / 2 + point_cloud_range[1]
self.z_offset = self.vz / 2 + point_cloud_range[2]
self.point_cloud_range = point_cloud_range

def forward(self, features: Tensor, num_points: Tensor, coors: Tensor,
*args, **kwargs) -> Tensor:
"""Forward function.
Args:
features (torch.Tensor): Point features or raw points in shape
(N, M, C).
num_points (torch.Tensor): Number of points in each pillar.
coors (torch.Tensor): Coordinates of each voxel.
Returns:
torch.Tensor: Features of pillars.
"""
features_ls = [features]
# Find distance of x, y, and z from cluster center
if self._with_cluster_center:
points_mean = features[:, :, :3].sum(
dim=1, keepdim=True) / num_points.type_as(features).view(
-1, 1, 1)
f_cluster = features[:, :, :3] - points_mean
features_ls.append(f_cluster)

# Find distance of x, y, and z from pillar center
dtype = features.dtype
if self._with_voxel_center:
center_feature_size = 3 if self.use_voxel_center_z else 2
if not self.legacy:
f_center = torch.zeros_like(features[:, :, :center_feature_size])
f_center[:, :, 0] = features[:, :, 0] - (
coors[:, 3].to(dtype).unsqueeze(1) * self.vx +
self.x_offset)
f_center[:, :, 1] = features[:, :, 1] - (
coors[:, 2].to(dtype).unsqueeze(1) * self.vy +
self.y_offset)
if self.use_voxel_center_z:
f_center[:, :, 2] = features[:, :, 2] - (
coors[:, 1].to(dtype).unsqueeze(1) * self.vz +
self.z_offset)
else:
f_center = features[:, :, :center_feature_size]
f_center[:, :, 0] = f_center[:, :, 0] - (
coors[:, 3].type_as(features).unsqueeze(1) * self.vx +
self.x_offset)
f_center[:, :, 1] = f_center[:, :, 1] - (
coors[:, 2].type_as(features).unsqueeze(1) * self.vy +
self.y_offset)
if self.use_voxel_center_z:
f_center[:, :, 2] = f_center[:, :, 2] - (
coors[:, 1].type_as(features).unsqueeze(1) * self.vz +
self.z_offset)
features_ls.append(f_center)

if self._with_distance:
points_dist = torch.norm(features[:, :, :3], 2, 2, keepdim=True)
features_ls.append(points_dist)

# Combine together feature decorations
features = torch.cat(features_ls, dim=-1)
# The feature decorations were calculated without regard to whether
# pillar was empty. Need to ensure that
# empty pillars remain set to zeros.
voxel_count = features.shape[1]
mask = get_paddings_indicator(num_points, voxel_count, axis=0)
mask = torch.unsqueeze(mask, -1).type_as(features)
features *= mask

for pfn in self.pfn_layers:
features = pfn(features, num_points)

return features.squeeze(1)
Loading

0 comments on commit ff616ff

Please sign in to comment.