Shortcuts

Source code for mmrotate.models.dense_heads.sam_reppoints_head

# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmcv.ops import DeformConv2d, min_area_polygons
from mmcv.runner import force_fp32
from mmdet.core import images_to_levels, multi_apply, unmap
from mmdet.core.anchor.point_generator import MlvlPointGenerator
from mmdet.core.utils import select_single_mlvl
from mmdet.models.dense_heads.base_dense_head import BaseDenseHead

from mmrotate.core import (build_assigner, build_sampler,
                           multiclass_nms_rotated, obb2poly, poly2obb)
from ..builder import ROTATED_HEADS, build_loss
from .utils import get_num_level_anchors_inside, points_center_pts


@ROTATED_HEADS.register_module()
class SAMRepPointsHead(BaseDenseHead):
    """Rotated RepPoints head for SASM.

    Args:
        num_classes (int): Number of classes.
        in_channels (int): Number of input channels.
        feat_channels (int): Number of feature channels.
        point_feat_channels (int, optional): Number of channels of points
            features.
        stacked_convs (int, optional): Number of stacked convolutions.
        num_points (int, optional): Number of points in points set. Must be
            an odd square number (asserted in ``__init__``).
        gradient_mul (float, optional): The multiplier to gradients from
            points refinement and recognition.
        point_strides (Iterable, optional): points strides.
        point_base_scale (int, optional): Bbox scale for assigning labels.
        conv_bias (str | bool, optional): The bias of convolution, either
            ``'auto'`` or a bool.
        loss_cls (dict, optional): Config of classification loss.
        loss_bbox_init (dict, optional): Config of initial points loss.
        loss_bbox_refine (dict, optional): Config of points loss in
            refinement.
        conv_cfg (dict, optional): The config of convolution.
        norm_cfg (dict, optional): The config of normalization.
        train_cfg (dict, optional): The config of train.
        test_cfg (dict, optional): The config of test.
        center_init (bool, optional): Whether to use center point assignment.
            Only stored on the instance by this class.
        transform_method (str, optional): The methods to transform RepPoints
            to bbox. Only ``'rotrect'`` is implemented.
        topk (int, optional): Number of the highest topk points.
            Defaults to 6. Only stored on the instance by this class.
        anti_factor (float, optional): Feature anti-aliasing coefficient.
            Only stored on the instance by this class.
        version (str, optional): Angle representations. Defaults to 'oc'.
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """

    def __init__(self,
                 num_classes,
                 in_channels,
                 feat_channels,
                 point_feat_channels=256,
                 stacked_convs=3,
                 num_points=9,
                 gradient_mul=0.1,
                 point_strides=[8, 16, 32, 64, 128],
                 point_base_scale=4,
                 conv_bias='auto',
                 loss_cls=dict(
                     type='FocalLoss',
                     use_sigmoid=True,
                     gamma=2.0,
                     alpha=0.25,
                     loss_weight=1.0),
                 loss_bbox_init=dict(
                     type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=0.5),
                 loss_bbox_refine=dict(
                     type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0),
                 conv_cfg=None,
                 norm_cfg=None,
                 train_cfg=None,
                 test_cfg=None,
                 center_init=True,
                 transform_method='rotrect',
                 topk=6,
                 anti_factor=0.75,
                 version='oc',
                 init_cfg=dict(
                     type='Normal',
                     layer='Conv2d',
                     std=0.01,
                     override=dict(
                         type='Normal',
                         name='reppoints_cls_out',
                         std=0.01,
                         bias_prob=0.01)),
                 **kwargs):
        # NOTE: the list/dict defaults above are shared between calls; they
        # are only read here, never mutated.
        super(SAMRepPointsHead, self).__init__(init_cfg)
        self.num_points = num_points
        self.point_feat_channels = point_feat_channels
        self.center_init = center_init

        # we use deform conv to extract points features
        self.dcn_kernel = int(np.sqrt(num_points))
        self.dcn_pad = int((self.dcn_kernel - 1) / 2)
        assert self.dcn_kernel * self.dcn_kernel == num_points, \
            'The points number should be a square number.'
        assert self.dcn_kernel % 2 == 1, \
            'The points number should be an odd square number.'
        # Regular (y, x) sampling grid of the deformable kernel, flattened
        # to one channel per coordinate: shape (1, 2 * num_points, 1, 1).
        dcn_base = np.arange(-self.dcn_pad,
                             self.dcn_pad + 1).astype(np.float64)
        dcn_base_y = np.repeat(dcn_base, self.dcn_kernel)
        dcn_base_x = np.tile(dcn_base, self.dcn_kernel)
        dcn_base_offset = np.stack([dcn_base_y, dcn_base_x], axis=1).reshape(
            (-1))
        self.dcn_base_offset = torch.tensor(dcn_base_offset).view(1, -1, 1, 1)

        self.num_classes = num_classes
        self.in_channels = in_channels
        self.feat_channels = feat_channels
        self.stacked_convs = stacked_convs
        assert conv_bias == 'auto' or isinstance(conv_bias, bool)
        self.conv_bias = conv_bias
        self.loss_cls = build_loss(loss_cls)
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.fp16_enabled = False
        self.gradient_mul = gradient_mul
        self.point_base_scale = point_base_scale
        self.point_strides = point_strides
        self.prior_generator = MlvlPointGenerator(
            self.point_strides, offset=0.)
        self.num_base_priors = self.prior_generator.num_base_priors[0]
        self.sampling = loss_cls['type'] not in ['FocalLoss']
        if self.train_cfg:
            self.init_assigner = build_assigner(self.train_cfg.init.assigner)
            self.refine_assigner = build_assigner(
                self.train_cfg.refine.assigner)
            # use PseudoSampler when sampling is False
            if self.sampling and hasattr(self.train_cfg, 'sampler'):
                sampler_cfg = self.train_cfg.sampler
            else:
                sampler_cfg = dict(type='PseudoSampler')
            self.sampler = build_sampler(sampler_cfg, context=self)
        self.transform_method = transform_method
        self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False)
        if self.use_sigmoid_cls:
            self.cls_out_channels = self.num_classes
        else:
            # softmax classification carries one extra background channel
            self.cls_out_channels = self.num_classes + 1
        self.loss_bbox_init = build_loss(loss_bbox_init)
        self.loss_bbox_refine = build_loss(loss_bbox_refine)
        self.topk = topk
        self.anti_factor = anti_factor
        self.version = version
        self._init_layers()

    def _init_layers(self):
        """Initialize layers of the head."""
        self.relu = nn.ReLU(inplace=True)
        self.cls_convs = nn.ModuleList()
        self.reg_convs = nn.ModuleList()
        for i in range(self.stacked_convs):
            chn = self.in_channels if i == 0 else self.feat_channels
            self.cls_convs.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    bias=self.conv_bias))
            self.reg_convs.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    bias=self.conv_bias))
        # one (y, x) offset pair per representative point
        pts_out_dim = 2 * self.num_points
        self.reppoints_cls_conv = DeformConv2d(self.feat_channels,
                                               self.point_feat_channels,
                                               self.dcn_kernel, 1,
                                               self.dcn_pad)
        self.reppoints_cls_out = nn.Conv2d(self.point_feat_channels,
                                           self.cls_out_channels, 1, 1, 0)
        self.reppoints_pts_init_conv = nn.Conv2d(self.feat_channels,
                                                 self.point_feat_channels, 3,
                                                 1, 1)
        self.reppoints_pts_init_out = nn.Conv2d(self.point_feat_channels,
                                                pts_out_dim, 1, 1, 0)
        self.reppoints_pts_refine_conv = DeformConv2d(self.feat_channels,
                                                      self.point_feat_channels,
                                                      self.dcn_kernel, 1,
                                                      self.dcn_pad)
        self.reppoints_pts_refine_out = nn.Conv2d(self.point_feat_channels,
                                                  pts_out_dim, 1, 1, 0)

    def points2rotrect(self, pts, y_first=True):
        """Convert points to oriented bboxes.

        Args:
            pts (Tensor): Point sets, reshaped internally to
                (-1, 2 * num_points).
            y_first (bool): If True, ``pts`` is stored as (y, x) pairs and is
                swapped to (x, y) before conversion.

        Returns:
            Tensor: Output of ``min_area_polygons`` on the (x, y) point sets.

        Raises:
            NotImplementedError: If ``transform_method`` is not 'rotrect'.
        """
        if y_first:
            pts = pts.reshape(-1, self.num_points, 2)
            pts_dy = pts[:, :, 0::2]
            pts_dx = pts[:, :, 1::2]
            pts = torch.cat([pts_dx, pts_dy],
                            dim=2).reshape(-1, 2 * self.num_points)
        if self.transform_method == 'rotrect':
            rotrect_pred = min_area_polygons(pts)
            return rotrect_pred
        else:
            raise NotImplementedError

    def forward(self, feats):
        """Forward function: apply ``forward_single`` to every level."""
        return multi_apply(self.forward_single, feats)

    def forward_single(self, x):
        """Forward feature map of a single FPN level.

        Args:
            x (Tensor): Feature map of one level.

        Returns:
            tuple[Tensor]: ``(cls_out, pts_out_init, pts_out_refine)``; the
            refined offsets are predicted as a residual on top of the
            detached initial offsets.
        """
        dcn_base_offset = self.dcn_base_offset.type_as(x)
        cls_feat = x
        pts_feat = x
        for cls_conv in self.cls_convs:
            cls_feat = cls_conv(cls_feat)
        for reg_conv in self.reg_convs:
            pts_feat = reg_conv(pts_feat)
        # initialize reppoints (offsets are relative to the cell center)
        pts_out_init = self.reppoints_pts_init_out(
            self.relu(self.reppoints_pts_init_conv(pts_feat)))
        # refine and classify reppoints; only a `gradient_mul` fraction of
        # the gradient flows back into the init branch through the offsets
        pts_out_init_grad_mul = (1 - self.gradient_mul) * pts_out_init.detach(
        ) + self.gradient_mul * pts_out_init
        dcn_offset = pts_out_init_grad_mul - dcn_base_offset
        cls_out = self.reppoints_cls_out(
            self.relu(self.reppoints_cls_conv(cls_feat, dcn_offset)))
        pts_out_refine = self.reppoints_pts_refine_out(
            self.relu(self.reppoints_pts_refine_conv(pts_feat, dcn_offset)))
        pts_out_refine = pts_out_refine + pts_out_init.detach()
        return cls_out, pts_out_init, pts_out_refine

    def get_points(self, featmap_sizes, img_metas, device):
        """Get points according to feature map sizes.

        Args:
            featmap_sizes (list[tuple]): Multi-level feature map sizes.
            img_metas (list[dict]): Image meta info.
            device: Device on which the points are generated.

        Returns:
            tuple: points of each image, valid flags of each image
        """
        num_imgs = len(img_metas)

        # since feature map sizes of all images are the same, we only compute
        # points center for one time
        multi_level_points = self.prior_generator.grid_priors(
            featmap_sizes, device=device, with_stride=True)
        points_list = [[point.clone() for point in multi_level_points]
                       for _ in range(num_imgs)]

        # for each image, we compute valid flags of multi level grids
        valid_flag_list = [
            self.prior_generator.valid_flags(featmap_sizes,
                                             img_meta['pad_shape'])
            for img_meta in img_metas
        ]
        return points_list, valid_flag_list

    def offset_to_pts(self, center_list, pred_list):
        """Change from point offset to point coordinate.

        Args:
            center_list (list[list[Tensor]]): Per-image, per-level point
                centers; the first two columns are used as the center.
            pred_list (list[Tensor]): Per-level predicted (y, x) offsets,
                indexed as ``pred_list[i_lvl][i_img]``.

        Returns:
            list[Tensor]: Per-level tensors of (x, y) point coordinates,
            stacked over images.
        """
        pts_list = []
        for i_lvl, stride in enumerate(self.point_strides):
            pts_lvl = []
            for i_img in range(len(center_list)):
                pts_center = center_list[i_img][i_lvl][:, :2].repeat(
                    1, self.num_points)
                pts_shift = pred_list[i_lvl][i_img]
                yx_pts_shift = pts_shift.permute(1, 2, 0).view(
                    -1, 2 * self.num_points)
                # network predicts (y, x); swap to (x, y)
                y_pts_shift = yx_pts_shift[..., 0::2]
                x_pts_shift = yx_pts_shift[..., 1::2]
                xy_pts_shift = torch.stack([x_pts_shift, y_pts_shift], -1)
                xy_pts_shift = xy_pts_shift.view(*yx_pts_shift.shape[:-1], -1)
                pts = xy_pts_shift * stride + pts_center
                pts_lvl.append(pts)
            pts_lvl = torch.stack(pts_lvl, 0)
            pts_list.append(pts_lvl)
        return pts_list

    def _point_target_single(self,
                             flat_proposals,
                             num_level_proposals,
                             valid_flags,
                             gt_bboxes,
                             gt_bboxes_ignore,
                             gt_labels,
                             overlaps,
                             stage='init',
                             unmap_outputs=True):
        """Compute assignment targets and SAM weights for a single image.

        Returns:
            tuple: ``(labels, label_weights, bbox_gt, pos_proposals,
            proposals_weights, pos_inds, neg_inds, gt_inds, sam_weights)``,
            or a 9-tuple of ``None`` when no proposal is valid.
        """
        inside_flags = valid_flags
        if not inside_flags.any():
            return (None, ) * 9
        # assign gt and sample proposals
        proposals = flat_proposals[inside_flags, :]
        num_level_anchors_inside = get_num_level_anchors_inside(
            num_level_proposals, inside_flags)

        # convert gt from obb to poly
        gt_bboxes = obb2poly(gt_bboxes, self.version)

        if stage == 'init':
            assigner = self.init_assigner
            pos_weight = self.train_cfg.init.pos_weight
            assign_result = assigner.assign(
                proposals, gt_bboxes, gt_bboxes_ignore,
                None if self.sampling else gt_labels, overlaps)
        else:
            assigner = self.refine_assigner
            pos_weight = self.train_cfg.refine.pos_weight
            if self.train_cfg.refine.assigner.type not in (
                    'ATSSAssigner', 'ATSSConvexAssigner', 'SASAssigner'):
                assign_result = assigner.assign(
                    proposals, gt_bboxes, overlaps, gt_bboxes_ignore,
                    None if self.sampling else gt_labels)
            else:
                # ATSS-style assigners additionally need per-level counts
                assign_result = assigner.assign(
                    proposals, num_level_anchors_inside, gt_bboxes,
                    gt_bboxes_ignore, None if self.sampling else gt_labels)

        sampling_result = self.sampler.sample(assign_result, proposals,
                                              gt_bboxes)

        gt_inds = assign_result.gt_inds
        num_valid_proposals = proposals.shape[0]
        bbox_gt = proposals.new_zeros([num_valid_proposals, 8])
        pos_proposals = torch.zeros_like(proposals)
        proposals_weights = proposals.new_zeros(num_valid_proposals)
        # background label is `num_classes`
        labels = proposals.new_full((num_valid_proposals, ),
                                    self.num_classes,
                                    dtype=torch.long)
        label_weights = proposals.new_zeros(
            num_valid_proposals, dtype=torch.float)

        pos_inds = sampling_result.pos_inds
        neg_inds = sampling_result.neg_inds
        if len(pos_inds) > 0:
            pos_gt_bboxes = sampling_result.pos_gt_bboxes
            bbox_gt[pos_inds, :] = pos_gt_bboxes
            pos_proposals[pos_inds, :] = proposals[pos_inds, :]
            proposals_weights[pos_inds] = 1.0
            if gt_labels is None:
                # Only rpn gives gt_labels as None
                # Foreground is the first class
                labels[pos_inds] = 0
            else:
                labels[pos_inds] = gt_labels[
                    sampling_result.pos_assigned_gt_inds]
            if pos_weight <= 0:
                label_weights[pos_inds] = 1.0
            else:
                label_weights[pos_inds] = pos_weight
        if len(neg_inds) > 0:
            label_weights[neg_inds] = 1.0

        # Shape-adaptive weighting: distance between each proposal and its
        # gt center, normalised by the box extent along each axis.
        rbboxes_center, width, height, angles = torch.split(
            poly2obb(bbox_gt, self.version), [2, 1, 1, 1], dim=-1)
        # squeeze(-1) instead of squeeze(): keeps a 1-D shape even when a
        # single proposal is valid, so boolean-mask indexing stays valid.
        width = width.squeeze(-1)
        height = height.squeeze(-1)
        angles = angles.squeeze(-1)

        if stage == 'init':
            points_xy = pos_proposals[:, :2]
        else:
            points_xy = points_center_pts(pos_proposals, y_first=True)

        dx_sq = torch.pow(rbboxes_center[:, 0] - points_xy[:, 0], 2)
        dy_sq = torch.pow(rbboxes_center[:, 1] - points_xy[:, 1], 2)

        distances = torch.zeros_like(angles)
        # rows with width == 0 are unassigned (all-zero gt) and keep 0
        angles_index_wh = (width != 0) & (angles >= 0) & (angles <= 1.57)
        angles_index_hw = (width != 0) & ((angles < 0) | (angles > 1.57))
        distances[angles_index_wh] = torch.sqrt(
            dx_sq[angles_index_wh] / width[angles_index_wh] +
            dy_sq[angles_index_wh] / height[angles_index_wh])
        distances[angles_index_hw] = torch.sqrt(
            dx_sq[angles_index_hw] / height[angles_index_hw] +
            dy_sq[angles_index_hw] / width[angles_index_hw])
        # `x == float('nan')` is never True; use isnan to actually reset
        # undefined distances (e.g. 0 / 0 on degenerate boxes).
        distances[torch.isnan(distances)] = 0.

        sam_weights = label_weights * (torch.exp(1 / (distances + 1)))
        sam_weights[sam_weights == float('inf')] = 0.

        # map up to original set of proposals
        if unmap_outputs:
            num_total_proposals = flat_proposals.size(0)
            labels = unmap(labels, num_total_proposals, inside_flags)
            label_weights = unmap(label_weights, num_total_proposals,
                                  inside_flags)
            bbox_gt = unmap(bbox_gt, num_total_proposals, inside_flags)
            pos_proposals = unmap(pos_proposals, num_total_proposals,
                                  inside_flags)
            proposals_weights = unmap(proposals_weights, num_total_proposals,
                                      inside_flags)
            gt_inds = unmap(gt_inds, num_total_proposals, inside_flags)
            sam_weights = unmap(sam_weights, num_total_proposals,
                                inside_flags)

        return (labels, label_weights, bbox_gt, pos_proposals,
                proposals_weights, pos_inds, neg_inds, gt_inds, sam_weights)

    def get_targets(self,
                    proposals_list,
                    valid_flag_list,
                    gt_bboxes_list,
                    img_metas,
                    gt_bboxes_ignore_list=None,
                    gt_labels_list=None,
                    stage='init',
                    label_channels=1,
                    unmap_outputs=True):
        """Compute corresponding GT box and classification targets for
        proposals.

        Note:
            ``proposals_list`` and ``valid_flag_list`` are modified in place:
            each per-image list of levels is replaced by one concatenated
            tensor.

        Args:
            proposals_list (list[list]): Multi level points/bboxes of each
                image.
            valid_flag_list (list[list]): Multi level valid flags of each
                image.
            gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image.
            img_metas (list[dict]): Meta info of each image.
            gt_bboxes_ignore_list (list[Tensor]): Ground truth bboxes to be
                ignored.
            gt_labels_list (list[Tensor]): Ground truth labels of each box.
            stage (str): `init` or `refine`. Generate target for init stage or
                refine stage
            label_channels (int): Channel of label. Not read by this method.
            unmap_outputs (bool): Whether to map outputs back to the original
                set of anchors.

        Returns:
            tuple | None: ``None`` if any image has no valid point, else

                - labels_list (list[Tensor]): Labels of each level.
                - label_weights_list (list[Tensor]): Label weights of each
                  level.
                - bbox_gt_list (list[Tensor]): Ground truth bbox of each
                  level.
                - proposal_list (list[Tensor]): Proposals(points/bboxes) of
                  each level.
                - proposal_weights_list (list[Tensor]): Proposal weights of
                  each level.
                - num_total_pos (int): Number of positive samples in all
                  images.
                - num_total_neg (int): Number of negative samples in all
                  images.
                - gt_inds_list (list[Tensor]): Assigned gt indices of each
                  level.
                - sam_weights_list (list[Tensor]): SAM weights of each level.
        """
        assert stage in ['init', 'refine']
        num_imgs = len(img_metas)
        assert len(proposals_list) == len(valid_flag_list) == num_imgs

        # points number of multi levels
        num_level_proposals = [points.size(0) for points in proposals_list[0]]
        num_level_proposals_list = [num_level_proposals] * num_imgs

        # concat all level points and flags to a single tensor
        for i in range(num_imgs):
            assert len(proposals_list[i]) == len(valid_flag_list[i])
            proposals_list[i] = torch.cat(proposals_list[i])
            valid_flag_list[i] = torch.cat(valid_flag_list[i])

        # compute targets for each image
        if gt_bboxes_ignore_list is None:
            gt_bboxes_ignore_list = [None for _ in range(num_imgs)]
        if gt_labels_list is None:
            gt_labels_list = [None for _ in range(num_imgs)]
        # no precomputed overlaps are supplied to the assigners
        all_overlaps_rotate_list = [None] * len(gt_bboxes_list)
        (all_labels, all_label_weights, all_bbox_gt, all_proposals,
         all_proposal_weights, pos_inds_list, neg_inds_list, all_gt_inds_list,
         all_sam_init_weights) = multi_apply(
             self._point_target_single,
             proposals_list,
             num_level_proposals_list,
             valid_flag_list,
             gt_bboxes_list,
             gt_bboxes_ignore_list,
             gt_labels_list,
             all_overlaps_rotate_list,
             stage=stage,
             unmap_outputs=unmap_outputs)
        # no valid points
        if any(labels is None for labels in all_labels):
            return None
        # sampled points of all images
        num_total_pos = sum([max(inds.numel(), 1) for inds in pos_inds_list])
        num_total_neg = sum([max(inds.numel(), 1) for inds in neg_inds_list])
        labels_list = images_to_levels(all_labels, num_level_proposals)
        label_weights_list = images_to_levels(all_label_weights,
                                              num_level_proposals)
        bbox_gt_list = images_to_levels(all_bbox_gt, num_level_proposals)
        proposals_list = images_to_levels(all_proposals, num_level_proposals)
        proposal_weights_list = images_to_levels(all_proposal_weights,
                                                 num_level_proposals)
        gt_inds_list = images_to_levels(all_gt_inds_list, num_level_proposals)
        sam_init_weights_list = images_to_levels(all_sam_init_weights,
                                                 num_level_proposals)

        return (labels_list, label_weights_list, bbox_gt_list, proposals_list,
                proposal_weights_list, num_total_pos, num_total_neg,
                gt_inds_list, sam_init_weights_list)

    def loss_single(self, cls_score, pts_pred_init, pts_pred_refine, labels,
                    label_weights, rbbox_gt_init, convex_weights_init,
                    sam_weights_init, rbbox_gt_refine, convex_weights_refine,
                    sam_weights_refine, stride, num_total_samples_refine):
        """Single loss function (one FPN level).

        Returns:
            tuple[Tensor]: ``(loss_cls, loss_pts_init, loss_pts_refine)``.
        """
        normalize_term = self.point_base_scale * stride

        # init points loss, on positive samples only
        rbbox_gt_init = rbbox_gt_init.reshape(-1, 8)
        convex_weights_init = convex_weights_init.reshape(-1)
        sam_weights_init = sam_weights_init.reshape(-1)
        pts_pred_init = pts_pred_init.reshape(-1, 2 * self.num_points)
        pos_ind_init = (convex_weights_init > 0).nonzero(
            as_tuple=False).reshape(-1)
        pts_pred_init_norm = pts_pred_init[pos_ind_init]
        rbbox_gt_init_norm = rbbox_gt_init[pos_ind_init]
        convex_weights_pos_init = convex_weights_init[pos_ind_init]
        sam_weights_pos_init = sam_weights_init[pos_ind_init]
        loss_pts_init = self.loss_bbox_init(
            pts_pred_init_norm / normalize_term,
            rbbox_gt_init_norm / normalize_term,
            convex_weights_pos_init * sam_weights_pos_init)

        # refine points loss, on positive samples only
        rbbox_gt_refine = rbbox_gt_refine.reshape(-1, 8)
        pts_pred_refine = pts_pred_refine.reshape(-1, 2 * self.num_points)
        convex_weights_refine = convex_weights_refine.reshape(-1)
        sam_weights_refine = sam_weights_refine.reshape(-1)
        pos_ind_refine = (convex_weights_refine > 0).nonzero(
            as_tuple=False).reshape(-1)
        pts_pred_refine_norm = pts_pred_refine[pos_ind_refine]
        rbbox_gt_refine_norm = rbbox_gt_refine[pos_ind_refine]
        convex_weights_pos_refine = convex_weights_refine[pos_ind_refine]
        sam_weights_pos_refine = sam_weights_refine[pos_ind_refine]
        loss_pts_refine = self.loss_bbox_refine(
            pts_pred_refine_norm / normalize_term,
            rbbox_gt_refine_norm / normalize_term,
            convex_weights_pos_refine * sam_weights_pos_refine)

        # classification loss, weighted by the refine-stage SAM weights
        labels = labels.reshape(-1)
        label_weights = label_weights.reshape(-1)
        cls_score = cls_score.permute(0, 2, 3,
                                      1).reshape(-1, self.cls_out_channels)
        loss_cls = self.loss_cls(
            cls_score,
            labels,
            label_weights * sam_weights_refine,
            avg_factor=num_total_samples_refine)
        return loss_cls, loss_pts_init, loss_pts_refine

    def loss(self,
             cls_scores,
             pts_preds_init,
             pts_preds_refine,
             gt_bboxes,
             gt_labels,
             img_metas,
             gt_bboxes_ignore=None):
        """Loss function of SAM RepPoints head.

        Returns:
            dict: ``loss_cls``, ``loss_pts_init`` and ``loss_pts_refine``,
            each a per-level list of losses.

        Raises:
            NotImplementedError: If the init assigner is not
                'ConvexAssigner'.
        """
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        assert len(featmap_sizes) == self.prior_generator.num_levels
        label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
        device = cls_scores[0].device

        # target for initial stage
        center_list, valid_flag_list = self.get_points(
            featmap_sizes, img_metas, device=device)
        pts_coordinate_preds_init = self.offset_to_pts(center_list,
                                                       pts_preds_init)
        if self.train_cfg.init.assigner['type'] == 'ConvexAssigner':
            candidate_list = center_list
        else:
            raise NotImplementedError
        cls_reg_targets_init = self.get_targets(
            candidate_list,
            valid_flag_list,
            gt_bboxes,
            img_metas,
            gt_bboxes_ignore_list=gt_bboxes_ignore,
            gt_labels_list=gt_labels,
            stage='init',
            label_channels=label_channels)
        (*_, rbbox_gt_list_init, candidate_list_init, convex_weights_list_init,
         num_total_pos_init, num_total_neg_init, gt_inds_init,
         sam_weights_list_init) = cls_reg_targets_init

        # target for refinement stage; get_targets() flattened the lists
        # above in place, so the per-level points are generated again
        center_list, valid_flag_list = self.get_points(
            featmap_sizes, img_metas, device=device)
        pts_coordinate_preds_refine = self.offset_to_pts(
            center_list, pts_preds_refine)

        # Per-level shifts from the (detached) init predictions are the same
        # for every image, so compute them once: (B, H, W, 2 * num_points).
        num_levels = len(pts_preds_refine)
        mlvl_points_shift = [
            pts_preds_init[i_lvl].detach().permute(0, 2, 3, 1) *
            self.point_strides[i_lvl] for i_lvl in range(num_levels)
        ]
        points_list = []
        for i_img, center in enumerate(center_list):
            points = []
            for i_lvl in range(num_levels):
                points_center = center[i_lvl][:, :2].repeat(1, self.num_points)
                points.append(
                    points_center + mlvl_points_shift[i_lvl][i_img].reshape(
                        -1, 2 * self.num_points))
            points_list.append(points)

        cls_reg_targets_refine = self.get_targets(
            points_list,
            valid_flag_list,
            gt_bboxes,
            img_metas,
            gt_bboxes_ignore_list=gt_bboxes_ignore,
            gt_labels_list=gt_labels,
            stage='refine',
            label_channels=label_channels)
        (labels_list, label_weights_list, rbbox_gt_list_refine,
         candidate_list_refine, convex_weights_list_refine,
         num_total_pos_refine, num_total_neg_refine, gt_inds_refine,
         sam_weights_list_refine) = cls_reg_targets_refine
        num_total_samples_refine = (
            num_total_pos_refine +
            num_total_neg_refine if self.sampling else num_total_pos_refine)

        losses_cls, losses_pts_init, losses_pts_refine = multi_apply(
            self.loss_single,
            cls_scores,
            pts_coordinate_preds_init,
            pts_coordinate_preds_refine,
            labels_list,
            label_weights_list,
            rbbox_gt_list_init,
            convex_weights_list_init,
            sam_weights_list_init,
            rbbox_gt_list_refine,
            convex_weights_list_refine,
            sam_weights_list_refine,
            self.point_strides,
            num_total_samples_refine=num_total_samples_refine)
        loss_dict_all = {
            'loss_cls': losses_cls,
            'loss_pts_init': losses_pts_init,
            'loss_pts_refine': losses_pts_refine
        }
        return loss_dict_all

    @force_fp32(apply_to=('cls_scores', 'pts_preds_init', 'pts_preds_refine'))
    def get_bboxes(self,
                   cls_scores,
                   pts_preds_init,
                   pts_preds_refine,
                   img_metas,
                   cfg=None,
                   rescale=False,
                   with_nms=True,
                   **kwargs):
        """Transform network outputs of a batch into bbox results.

        Args:
            cls_scores (list[Tensor]): Classification scores for all
                scale levels, each is a 4D-tensor, has shape
                (batch_size, num_priors * num_classes, H, W).
            pts_preds_init (list[Tensor]): Initial point offsets for all
                scale levels, each has shape
                (batch_size, num_points * 2, H, W). Not used for decoding.
            pts_preds_refine (list[Tensor]): Refined point offsets for all
                scale levels, each has shape
                (batch_size, num_points * 2, H, W).
            img_metas (list[dict]): Image meta info.
            cfg (mmcv.Config, Optional): Test / postprocessing configuration,
                if None, test_cfg would be used. Default None.
            rescale (bool): If True, return boxes in original image space.
                Default False.
            with_nms (bool): If True, do nms before return boxes.
                Default True.

        Returns:
            list[tuple[Tensor, Tensor]]: One ``(det_bboxes, det_labels)``
            pair per image, as returned by ``_get_bboxes_single``.
        """
        assert len(cls_scores) == len(pts_preds_refine)

        num_levels = len(cls_scores)

        featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)]
        # dtype must be the tensor dtype (the device was passed here before)
        mlvl_priors = self.prior_generator.grid_priors(
            featmap_sizes,
            dtype=cls_scores[0].dtype,
            device=cls_scores[0].device)

        result_list = []

        for img_id, img_meta in enumerate(img_metas):
            cls_score_list = select_single_mlvl(cls_scores, img_id)
            point_pred_list = select_single_mlvl(pts_preds_refine, img_id)

            results = self._get_bboxes_single(cls_score_list, point_pred_list,
                                              mlvl_priors, img_meta, cfg,
                                              rescale, with_nms, **kwargs)
            result_list.append(results)

        return result_list

    def _get_bboxes_single(self,
                           cls_score_list,
                           point_pred_list,
                           mlvl_priors,
                           img_meta,
                           cfg,
                           rescale=False,
                           with_nms=True,
                           **kwargs):
        """Transform outputs of a single image into bbox predictions.

        Args:
            cls_score_list (list[Tensor]): Box scores from all scale
                levels of a single image, each item has shape
                (num_priors * num_classes, H, W).
            point_pred_list (list[Tensor]): Refined point offsets from all
                scale levels of a single image, each item has shape
                (num_points * 2, H, W).
            mlvl_priors (list[Tensor]): Each element in the list is
                the priors of a single level in feature pyramid; the first
                two columns are used as the point center.
            img_meta (dict): Image meta info.
            cfg (mmcv.Config): Test / postprocessing configuration,
                if None, test_cfg would be used.
            rescale (bool): If True, return boxes in original image space.
                Default: False.
            with_nms (bool): If True, do nms before return boxes.
                Default: True.

        Returns:
            tuple[Tensor]: ``(det_bboxes, det_labels)`` as produced by
            ``multiclass_nms_rotated`` (presumably boxes as
            (cx, cy, w, h, a, score) plus one label per box).

        Raises:
            NotImplementedError: If ``with_nms`` is False.
        """
        cfg = self.test_cfg if cfg is None else cfg
        assert len(cls_score_list) == len(point_pred_list)
        scale_factor = img_meta['scale_factor']
        nms_pre = cfg.get('nms_pre', -1)

        mlvl_bboxes = []
        mlvl_scores = []
        for level_idx, (cls_score, points_pred, points) in enumerate(
                zip(cls_score_list, point_pred_list, mlvl_priors)):
            assert cls_score.size()[-2:] == points_pred.size()[-2:]

            cls_score = cls_score.permute(1, 2,
                                          0).reshape(-1, self.cls_out_channels)
            if self.use_sigmoid_cls:
                scores = cls_score.sigmoid()
            else:
                # background is the last channel; keep foreground only
                scores = cls_score.softmax(-1)[:, :-1]

            points_pred = points_pred.permute(1, 2, 0).reshape(
                -1, 2 * self.num_points)
            if 0 < nms_pre < scores.shape[0]:
                # `scores` holds foreground classes only in both branches,
                # so rank over all of its columns.
                max_scores, _ = scores.max(dim=1)
                _, topk_inds = max_scores.topk(nms_pre)
                points = points[topk_inds, :]
                points_pred = points_pred[topk_inds, :]
                scores = scores[topk_inds, :]

            poly_pred = self.points2rotrect(points_pred, y_first=True)
            bbox_pos_center = points[:, :2].repeat(1, 4)
            polys = poly_pred * self.point_strides[level_idx] + bbox_pos_center
            bboxes = poly2obb(polys, self.version)

            mlvl_bboxes.append(bboxes)
            mlvl_scores.append(scores)

        mlvl_bboxes = torch.cat(mlvl_bboxes)

        if rescale:
            mlvl_bboxes[..., :4] /= mlvl_bboxes[..., :4].new_tensor(
                scale_factor)
        mlvl_scores = torch.cat(mlvl_scores)
        # Append a dummy background column so both the sigmoid and the
        # softmax branch hand NMS the same (n, num_classes + 1) layout.
        padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
        mlvl_scores = torch.cat([mlvl_scores, padding], dim=1)

        if with_nms:
            det_bboxes, det_labels = multiclass_nms_rotated(
                mlvl_bboxes, mlvl_scores, cfg.score_thr, cfg.nms,
                cfg.max_per_img)
            return det_bboxes, det_labels
        else:
            raise NotImplementedError
Read the Docs v: stable
Versions
latest
stable
1.x
v1.0.0rc0
v0.3.4
v0.3.3
v0.3.2
v0.3.1
v0.3.0
v0.2.0
v0.1.1
v0.1.0
main
dev
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.