Shortcuts

Source code for mmrotate.models.dense_heads.oriented_reppoints_head

# Copyright (c) OpenMMLab. All rights reserved.
import math

import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmcv.ops import DeformConv2d, chamfer_distance, min_area_polygons
from mmcv.runner import force_fp32
from mmdet.core import images_to_levels, multi_apply, unmap
from mmdet.core.anchor.point_generator import MlvlPointGenerator
from mmdet.core.utils import select_single_mlvl
from mmdet.models.dense_heads.base_dense_head import BaseDenseHead

from mmrotate.core import (build_assigner, build_sampler,
                           multiclass_nms_rotated, obb2poly, poly2obb)
from ..builder import ROTATED_HEADS, build_loss
from .utils import levels_to_images


def ChamferDistance2D(point_set_1,
                      point_set_2,
                      distance_weight=0.05,
                      eps=1e-12):
    """Compute the Chamfer distance between two point sets.

    Args:
        point_set_1 (torch.tensor): point set 1 with shape (N_pointsets,
                                    N_points, 2)
        point_set_2 (torch.tensor): point set 2 with shape (N_pointsets,
                                    N_points, 2)

    Returns:
        dist (torch.tensor): chamfer distance between two point sets
                             with shape (N_pointsets,)
    """
    assert point_set_1.dim() == point_set_2.dim()
    assert point_set_1.shape[-1] == point_set_2.shape[-1]
    assert point_set_1.dim() <= 3
    dist1, dist2, _, _ = chamfer_distance(point_set_1, point_set_2)
    dist1 = torch.sqrt(torch.clamp(dist1, eps))
    dist2 = torch.sqrt(torch.clamp(dist2, eps))
    dist = distance_weight * (dist1.mean(-1) + dist2.mean(-1)) / 2.0

    return dist


[docs]@ROTATED_HEADS.register_module() class OrientedRepPointsHead(BaseDenseHead): """Oriented RepPoints head -<https://arxiv.org/pdf/2105.11111v4.pdf>. The head contains initial and refined stages based on RepPoints. The initial stage regresses coarse point sets, and the refine stage further regresses the fine point sets. The APAA scheme based on the quality of point set samples in the paper is employed in refined stage. Args: num_classes (int): Number of classes. in_channels (int): Number of input channels. feat_channels (int): Number of feature channels. point_feat_channels (int, optional): Number of channels of points features. stacked_convs (int, optional): Number of stacked convolutions. num_points (int, optional): Number of points in points set. gradient_mul (float, optional): The multiplier to gradients from points refinement and recognition. point_strides (Iterable, optional): points strides. point_base_scale (int, optional): Bbox scale for assigning labels. conv_bias (str, optional): The bias of convolution. loss_cls (dict, optional): Config of classification loss. loss_bbox_init (dict, optional): Config of initial points loss. loss_bbox_refine (dict, optional): Config of points loss in refinement. conv_cfg (dict, optional): The config of convolution. norm_cfg (dict, optional): The config of normlization. train_cfg (dict, optional): The config of train. test_cfg (dict, optional): The config of test. center_init (bool, optional): Whether to use center point assignment. top_ratio (float, optional): Ratio of top high-quality point sets. Defaults to 0.4. init_qua_weight (float, optional): Quality weight of initial stage. ori_qua_weight (float, optional): Orientation quality weight. poc_qua_weight (float, optional): Point-wise correlation quality weight. version (str, optional): Angle representations. Defaults to 'oc'. init_cfg (dict or list[dict], optional): Initialization config dict. """ def __init__(self, num_classes, in_channels, feat_channels, point_feat_channels=256, stacked_convs=3, num_points=9, gradient_mul=0.1, point_strides=[8, 16, 32, 64, 128], point_base_scale=4, conv_bias='auto', loss_cls=dict( type='FocalLoss', use_sigmoid=True, gamma=2.0, alpha=0.25, loss_weight=1.0), loss_bbox_init=dict( type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=0.5), loss_bbox_refine=dict( type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0), loss_spatial_init=dict( type='SpatialBorderLoss', loss_weight=0.05), loss_spatial_refine=dict( type='SpatialBorderLoss', loss_weight=0.1), conv_cfg=None, norm_cfg=None, train_cfg=None, test_cfg=None, center_init=True, version='oc', top_ratio=0.4, init_qua_weight=0.2, ori_qua_weight=0.3, poc_qua_weight=0.1, init_cfg=dict( type='Normal', layer='Conv2d', std=0.01, override=dict( type='Normal', name='reppoints_cls_out', std=0.01, bias_prob=0.01)), **kwargs): super(OrientedRepPointsHead, self).__init__(init_cfg) self.num_points = num_points self.point_feat_channels = point_feat_channels self.center_init = center_init # we use deform conv to extract points features self.dcn_kernel = int(np.sqrt(num_points)) self.dcn_pad = int((self.dcn_kernel - 1) / 2) assert self.dcn_kernel * self.dcn_kernel == num_points, \ 'The points number should be a square number.' assert self.dcn_kernel % 2 == 1, \ 'The points number should be an odd square number.' dcn_base = np.arange(-self.dcn_pad, self.dcn_pad + 1).astype(np.float64) dcn_base_y = np.repeat(dcn_base, self.dcn_kernel) dcn_base_x = np.tile(dcn_base, self.dcn_kernel) dcn_base_offset = np.stack([dcn_base_y, dcn_base_x], axis=1).reshape( (-1)) self.dcn_base_offset = torch.tensor(dcn_base_offset).view(1, -1, 1, 1) self.num_classes = num_classes self.in_channels = in_channels self.feat_channels = feat_channels self.stacked_convs = stacked_convs assert conv_bias == 'auto' or isinstance(conv_bias, bool) self.conv_bias = conv_bias self.loss_cls = build_loss(loss_cls) self.train_cfg = train_cfg self.test_cfg = test_cfg self.conv_cfg = conv_cfg self.norm_cfg = norm_cfg self.fp16_enabled = False self.gradient_mul = gradient_mul self.point_base_scale = point_base_scale self.point_strides = point_strides self.prior_generator = MlvlPointGenerator( self.point_strides, offset=0.) self.num_base_priors = self.prior_generator.num_base_priors[0] self.sampling = loss_cls['type'] not in ['FocalLoss'] if self.train_cfg: self.init_assigner = build_assigner(self.train_cfg.init.assigner) self.refine_assigner = build_assigner( self.train_cfg.refine.assigner) # use PseudoSampler when sampling is False if self.sampling and hasattr(self.train_cfg, 'sampler'): sampler_cfg = self.train_cfg.sampler else: sampler_cfg = dict(type='PseudoSampler') self.sampler = build_sampler(sampler_cfg, context=self) self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False) if self.use_sigmoid_cls: self.cls_out_channels = self.num_classes else: self.cls_out_channels = self.num_classes + 1 self.loss_bbox_init = build_loss(loss_bbox_init) self.loss_bbox_refine = build_loss(loss_bbox_refine) self.loss_spatial_init = build_loss(loss_spatial_init) self.loss_spatial_refine = build_loss(loss_spatial_refine) self.init_qua_weight = init_qua_weight self.ori_qua_weight = ori_qua_weight self.poc_qua_weight = poc_qua_weight self.top_ratio = top_ratio self.version = version self._init_layers() def _init_layers(self): """Initialize layers of the head.""" self.relu = nn.ReLU(inplace=True) self.cls_convs = nn.ModuleList() self.reg_convs = nn.ModuleList() for i in range(self.stacked_convs): chn = self.in_channels if i == 0 else self.feat_channels self.cls_convs.append( ConvModule( chn, self.feat_channels, 3, stride=1, padding=1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, bias=self.conv_bias)) self.reg_convs.append( ConvModule( chn, self.feat_channels, 3, stride=1, padding=1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, bias=self.conv_bias)) pts_out_dim = 2 * self.num_points self.reppoints_cls_conv = DeformConv2d(self.feat_channels, self.point_feat_channels, self.dcn_kernel, 1, self.dcn_pad) self.reppoints_cls_out = nn.Conv2d(self.point_feat_channels, self.cls_out_channels, 1, 1, 0) self.reppoints_pts_init_conv = nn.Conv2d(self.feat_channels, self.point_feat_channels, 3, 1, 1) self.reppoints_pts_init_out = nn.Conv2d(self.point_feat_channels, pts_out_dim, 1, 1, 0) self.reppoints_pts_refine_conv = DeformConv2d(self.feat_channels, self.point_feat_channels, self.dcn_kernel, 1, self.dcn_pad) self.reppoints_pts_refine_out = nn.Conv2d(self.point_feat_channels, pts_out_dim, 1, 1, 0)
[docs] def forward(self, feats): """Forward function.""" return multi_apply(self.forward_single, feats)
[docs] def forward_single(self, x): """Forward feature map of a single FPN level. Args: x (torch.tensor): single-level feature map sizes. Returns: cls_out (torch.tensor): classification score prediction pts_out_init (torch.tensor): initial point sets prediction pts_out_refine (torch.tensor): refined point sets prediction base_feat: single-level feature as the basic feature map """ dcn_base_offset = self.dcn_base_offset.type_as(x) points_init = 0 cls_feat = x pts_feat = x base_feat = x for cls_conv in self.cls_convs: cls_feat = cls_conv(cls_feat) for reg_conv in self.reg_convs: pts_feat = reg_conv(pts_feat) # initialize reppoints pts_out_init = self.reppoints_pts_init_out( self.relu(self.reppoints_pts_init_conv(pts_feat))) pts_out_init = pts_out_init + points_init # refine and classify reppoints pts_out_init_grad_mul = (1 - self.gradient_mul) * pts_out_init.detach( ) + self.gradient_mul * pts_out_init dcn_offset = pts_out_init_grad_mul - dcn_base_offset cls_out = self.reppoints_cls_out( self.relu(self.reppoints_cls_conv(cls_feat, dcn_offset))) pts_out_refine = self.reppoints_pts_refine_out( self.relu(self.reppoints_pts_refine_conv(pts_feat, dcn_offset))) pts_out_refine = pts_out_refine + pts_out_init.detach() return cls_out, pts_out_init, pts_out_refine, base_feat
[docs] def get_points(self, featmap_sizes, img_metas, device): """Get points according to feature map sizes. Args: featmap_sizes (list[tuple]): Multi-level feature map sizes. img_metas (list[dict]): Image meta info. Returns: tuple: points of each image, valid flags of each image """ num_imgs = len(img_metas) multi_level_points = self.prior_generator.grid_priors( featmap_sizes, device=device, with_stride=True) points_list = [[point.clone() for point in multi_level_points] for _ in range(num_imgs)] valid_flag_list = [] for img_id, img_meta in enumerate(img_metas): multi_level_flags = self.prior_generator.valid_flags( featmap_sizes, img_meta['pad_shape']) valid_flag_list.append(multi_level_flags) return points_list, valid_flag_list
[docs] def offset_to_pts(self, center_list, pred_list): """Change from point offset to point coordinate.""" pts_list = [] for i_lvl, _ in enumerate(self.point_strides): pts_lvl = [] for i_img, _ in enumerate(center_list): pts_center = center_list[i_img][i_lvl][:, :2].repeat( 1, self.num_points) pts_shift = pred_list[i_lvl][i_img] yx_pts_shift = pts_shift.permute(1, 2, 0).view( -1, 2 * self.num_points) y_pts_shift = yx_pts_shift[..., 0::2] x_pts_shift = yx_pts_shift[..., 1::2] xy_pts_shift = torch.stack([x_pts_shift, y_pts_shift], -1) xy_pts_shift = xy_pts_shift.view(*yx_pts_shift.shape[:-1], -1) pts = xy_pts_shift * self.point_strides[i_lvl] + pts_center pts_lvl.append(pts) pts_lvl = torch.stack(pts_lvl, 0) pts_list.append(pts_lvl) return pts_list
[docs] def sampling_points(self, polygons, points_num, device): """Sample edge points for polygon. Args: polygons (torch.tensor): polygons with shape (N, 8) points_num (int): number of sampling points for each polygon edge. 10 by default. Returns: sampling_points (torch.tensor): sampling points with shape (N, points_num*4, 2) """ polygons_xs, polygons_ys = polygons[:, 0::2], polygons[:, 1::2] ratio = torch.linspace(0, 1, points_num).to(device).repeat( polygons.shape[0], 1) edge_pts_x = [] edge_pts_y = [] for i in range(4): if i < 3: points_x = ratio * polygons_xs[:, i + 1:i + 2] + ( 1 - ratio) * polygons_xs[:, i:i + 1] points_y = ratio * polygons_ys[:, i + 1:i + 2] + ( 1 - ratio) * polygons_ys[:, i:i + 1] else: points_x = ratio * polygons_xs[:, 0].unsqueeze(1) + ( 1 - ratio) * polygons_xs[:, i].unsqueeze(1) points_y = ratio * polygons_ys[:, 0].unsqueeze(1) + ( 1 - ratio) * polygons_ys[:, i].unsqueeze(1) edge_pts_x.append(points_x) edge_pts_y.append(points_y) sampling_points_x = torch.cat(edge_pts_x, dim=1).unsqueeze(dim=2) sampling_points_y = torch.cat(edge_pts_y, dim=1).unsqueeze(dim=2) sampling_points = torch.cat([sampling_points_x, sampling_points_y], dim=2) return sampling_points
[docs] def get_adaptive_points_feature(self, features, pt_locations, stride): """Get the points features from the locations of predicted points. Args: features (torch.tensor): base feature with shape (B,C,W,H) pt_locations (torch.tensor): locations of points in each point set with shape (B, N_points_set(number of point set), N_points(number of points in each point set) *2) Returns: tensor: sampling features with (B, C, N_points_set, N_points) """ h = features.shape[2] * stride w = features.shape[3] * stride pt_locations = pt_locations.view(pt_locations.shape[0], pt_locations.shape[1], -1, 2).clone() pt_locations[..., 0] = pt_locations[..., 0] / (w / 2.) - 1 pt_locations[..., 1] = pt_locations[..., 1] / (h / 2.) - 1 batch_size = features.size(0) sampled_features = torch.zeros([ pt_locations.shape[0], features.size(1), pt_locations.size(1), pt_locations.size(2) ]).to(pt_locations.device) for i in range(batch_size): feature = nn.functional.grid_sample(features[i:i + 1], pt_locations[i:i + 1])[0] sampled_features[i] = feature return sampled_features,
[docs] def feature_cosine_similarity(self, points_features): """Compute the points features similarity for points-wise correlation. Args: points_features (torch.tensor): sampling point feature with shape (N_pointsets, N_points, C) Returns: max_correlation: max feature similarity in each point set with shape (N_points_set, N_points, C) """ mean_points_feats = torch.mean(points_features, dim=1, keepdim=True) norm_pts_feats = torch.norm( points_features, p=2, dim=2).unsqueeze(dim=2).clamp(min=1e-2) norm_mean_pts_feats = torch.norm( mean_points_feats, p=2, dim=2).unsqueeze(dim=2).clamp(min=1e-2) unity_points_features = points_features / norm_pts_feats unity_mean_points_feats = mean_points_feats / norm_mean_pts_feats cos_similarity = nn.CosineSimilarity(dim=2, eps=1e-6) feats_similarity = 1.0 - cos_similarity(unity_points_features, unity_mean_points_feats) max_correlation, _ = torch.max(feats_similarity, dim=1) return max_correlation
[docs] def pointsets_quality_assessment(self, pts_features, cls_score, pts_pred_init, pts_pred_refine, label, bbox_gt, label_weight, bbox_weight, pos_inds): """Assess the quality of each point set from the classification, localization, orientation, and point-wise correlation based on the assigned point sets samples. Args: pts_features (torch.tensor): points features with shape (N, 9, C) cls_score (torch.tensor): classification scores with shape (N, class_num) pts_pred_init (torch.tensor): initial point sets prediction with shape (N, 9*2) pts_pred_refine (torch.tensor): refined point sets prediction with shape (N, 9*2) label (torch.tensor): gt label with shape (N) bbox_gt(torch.tensor): gt bbox of polygon with shape (N, 8) label_weight (torch.tensor): label weight with shape (N) bbox_weight (torch.tensor): box weight with shape (N) pos_inds (torch.tensor): the inds of positive point set samples Returns: qua (torch.tensor) : weighted quality values for positive point set samples. """ device = cls_score.device pos_scores = cls_score[pos_inds] pos_pts_pred_init = pts_pred_init[pos_inds] pos_pts_pred_refine = pts_pred_refine[pos_inds] pos_pts_refine_features = pts_features[pos_inds] pos_bbox_gt = bbox_gt[pos_inds] pos_label = label[pos_inds] pos_label_weight = label_weight[pos_inds] pos_bbox_weight = bbox_weight[pos_inds] # quality of point-wise correlation qua_poc = self.poc_qua_weight * self.feature_cosine_similarity( pos_pts_refine_features) qua_cls = self.loss_cls( pos_scores, pos_label, pos_label_weight, avg_factor=self.loss_cls.loss_weight, reduction_override='none') polygons_pred_init = min_area_polygons(pos_pts_pred_init) polygons_pred_refine = min_area_polygons(pos_pts_pred_refine) sampling_pts_pred_init = self.sampling_points( polygons_pred_init, 10, device=device) sampling_pts_pred_refine = self.sampling_points( polygons_pred_refine, 10, device=device) sampling_pts_gt = self.sampling_points(pos_bbox_gt, 10, device=device) # quality of orientation qua_ori_init = self.ori_qua_weight * ChamferDistance2D( sampling_pts_gt, sampling_pts_pred_init) qua_ori_refine = self.ori_qua_weight * ChamferDistance2D( sampling_pts_gt, sampling_pts_pred_refine) # quality of localization qua_loc_init = self.loss_bbox_refine( pos_pts_pred_init, pos_bbox_gt, pos_bbox_weight, avg_factor=self.loss_cls.loss_weight, reduction_override='none') qua_loc_refine = self.loss_bbox_refine( pos_pts_pred_refine, pos_bbox_gt, pos_bbox_weight, avg_factor=self.loss_cls.loss_weight, reduction_override='none') # quality of classification qua_cls = qua_cls.sum(-1) # weighted inti-stage and refine-stage qua = qua_cls + self.init_qua_weight * ( qua_loc_init + qua_ori_init) + (1.0 - self.init_qua_weight) * ( qua_loc_refine + qua_ori_refine) + qua_poc return qua,
[docs] def dynamic_pointset_samples_selection(self, quality, label, label_weight, bbox_weight, pos_inds, pos_gt_inds, num_proposals_each_level=None, num_level=None): """The dynamic top k selection of point set samples based on the quality assessment values. Args: quality (torch.tensor): the quality values of positive point set samples label (torch.tensor): gt label with shape (N) bbox_gt(torch.tensor): gt bbox of polygon with shape (N, 8) label_weight (torch.tensor): label weight with shape (N) bbox_weight (torch.tensor): box weight with shape (N) pos_inds (torch.tensor): the inds of positive point set samples num_proposals_each_level (list[int]): proposals number of each level num_level (int): the level number Returns: label: gt label with shape (N) label_weight: label weight with shape (N) bbox_weight: box weight with shape (N) num_pos (int): the number of selected positive point samples with high-qualty pos_normalize_term (torch.tensor): the corresponding positive normalize term """ if len(pos_inds) == 0: return label, label_weight, bbox_weight, 0, torch.tensor( []).type_as(bbox_weight) num_gt = pos_gt_inds.max() num_proposals_each_level_ = num_proposals_each_level.copy() num_proposals_each_level_.insert(0, 0) inds_level_interval = np.cumsum(num_proposals_each_level_) pos_level_mask = [] for i in range(num_level): mask = (pos_inds >= inds_level_interval[i]) & ( pos_inds < inds_level_interval[i + 1]) pos_level_mask.append(mask) pos_inds_after_select = [] ignore_inds_after_select = [] for gt_ind in range(num_gt): pos_inds_select = [] pos_loss_select = [] gt_mask = pos_gt_inds == (gt_ind + 1) for level in range(num_level): level_mask = pos_level_mask[level] level_gt_mask = level_mask & gt_mask value, topk_inds = quality[level_gt_mask].topk( min(level_gt_mask.sum(), 6), largest=False) pos_inds_select.append(pos_inds[level_gt_mask][topk_inds]) pos_loss_select.append(value) pos_inds_select = torch.cat(pos_inds_select) pos_loss_select = torch.cat(pos_loss_select) if len(pos_inds_select) < 2: pos_inds_after_select.append(pos_inds_select) ignore_inds_after_select.append(pos_inds_select.new_tensor([])) else: pos_loss_select, sort_inds = pos_loss_select.sort( ) # small to large pos_inds_select = pos_inds_select[sort_inds] # dynamic top k topk = math.ceil(pos_loss_select.shape[0] * self.top_ratio) pos_inds_select_topk = pos_inds_select[:topk] pos_inds_after_select.append(pos_inds_select_topk) ignore_inds_after_select.append( pos_inds_select_topk.new_tensor([])) pos_inds_after_select = torch.cat(pos_inds_after_select) ignore_inds_after_select = torch.cat(ignore_inds_after_select) reassign_mask = (pos_inds.unsqueeze(1) != pos_inds_after_select).all(1) reassign_ids = pos_inds[reassign_mask] label[reassign_ids] = self.num_classes label_weight[ignore_inds_after_select] = 0 bbox_weight[reassign_ids] = 0 num_pos = len(pos_inds_after_select) pos_level_mask_after_select = [] for i in range(num_level): mask = (pos_inds_after_select >= inds_level_interval[i]) & ( pos_inds_after_select < inds_level_interval[i + 1]) pos_level_mask_after_select.append(mask) pos_level_mask_after_select = torch.stack(pos_level_mask_after_select, 0).type_as(label) pos_normalize_term = pos_level_mask_after_select * ( self.point_base_scale * torch.as_tensor(self.point_strides).type_as(label)).reshape(-1, 1) pos_normalize_term = pos_normalize_term[ pos_normalize_term > 0].type_as(bbox_weight) assert len(pos_normalize_term) == len(pos_inds_after_select) return label, label_weight, bbox_weight, num_pos, pos_normalize_term
[docs] def init_loss_single(self, pts_pred_init, bbox_gt_init, bbox_weights_init, stride): """Single initial stage loss function.""" normalize_term = self.point_base_scale * stride bbox_gt_init = bbox_gt_init.reshape(-1, 8) bbox_weights_init = bbox_weights_init.reshape(-1) pts_pred_init = pts_pred_init.reshape(-1, 2 * self.num_points) pos_ind_init = (bbox_weights_init > 0).nonzero( as_tuple=False).reshape(-1) pts_pred_init_norm = pts_pred_init[pos_ind_init] bbox_gt_init_norm = bbox_gt_init[pos_ind_init] bbox_weights_pos_init = bbox_weights_init[pos_ind_init] loss_pts_init = self.loss_bbox_init( pts_pred_init_norm / normalize_term, bbox_gt_init_norm / normalize_term, bbox_weights_pos_init) loss_border_init = self.loss_spatial_init( pts_pred_init_norm.reshape(-1, 2 * self.num_points) / normalize_term, bbox_gt_init_norm / normalize_term, bbox_weights_pos_init, avg_factor=None) return loss_pts_init, loss_border_init
def _point_target_single(self, flat_proposals, valid_flags, gt_bboxes, gt_bboxes_ignore, gt_labels, overlaps, stage='init', unmap_outputs=True): """Single point target function for initial and refine stage.""" inside_flags = valid_flags if not inside_flags.any(): return (None, ) * 8 # assign gt and sample proposals proposals = flat_proposals[inside_flags, :] if stage == 'init': assigner = self.init_assigner pos_weight = self.train_cfg.init.pos_weight else: assigner = self.refine_assigner pos_weight = self.train_cfg.refine.pos_weight # convert gt from obb to poly gt_bboxes = obb2poly(gt_bboxes, self.version) assign_result = assigner.assign(proposals, gt_bboxes, overlaps, gt_bboxes_ignore, None if self.sampling else gt_labels) sampling_result = self.sampler.sample(assign_result, proposals, gt_bboxes) gt_inds = assign_result.gt_inds num_valid_proposals = proposals.shape[0] bbox_gt = proposals.new_zeros([num_valid_proposals, 8]) pos_proposals = torch.zeros_like(proposals) proposals_weights = proposals.new_zeros(num_valid_proposals) labels = proposals.new_full((num_valid_proposals, ), self.num_classes, dtype=torch.long) label_weights = proposals.new_zeros( num_valid_proposals, dtype=torch.float) pos_inds = sampling_result.pos_inds neg_inds = sampling_result.neg_inds if len(pos_inds) > 0: pos_gt_bboxes = sampling_result.pos_gt_bboxes bbox_gt[pos_inds, :] = pos_gt_bboxes pos_proposals[pos_inds, :] = proposals[pos_inds, :] proposals_weights[pos_inds] = 1.0 if gt_labels is None: # Only rpn gives gt_labels as None # Foreground is the first class labels[pos_inds] = 0 else: labels[pos_inds] = gt_labels[ sampling_result.pos_assigned_gt_inds] if pos_weight <= 0: label_weights[pos_inds] = 1.0 else: label_weights[pos_inds] = pos_weight if len(neg_inds) > 0: label_weights[neg_inds] = 1.0 # map up to original set of proposals if unmap_outputs: num_total_proposals = flat_proposals.size(0) labels = unmap(labels, num_total_proposals, inside_flags) label_weights = unmap(label_weights, num_total_proposals, inside_flags) bbox_gt = unmap(bbox_gt, num_total_proposals, inside_flags) pos_proposals = unmap(pos_proposals, num_total_proposals, inside_flags) proposals_weights = unmap(proposals_weights, num_total_proposals, inside_flags) gt_inds = unmap(gt_inds, num_total_proposals, inside_flags) return (labels, label_weights, bbox_gt, pos_proposals, proposals_weights, pos_inds, neg_inds, gt_inds, sampling_result)
[docs] def get_targets(self, proposals_list, valid_flag_list, gt_bboxes_list, img_metas, gt_bboxes_ignore_list=None, gt_labels_list=None, stage='init', label_channels=1, unmap_outputs=True): """Compute corresponding GT box and classification targets for proposals in initial stage. Args: proposals_list (list[list]): Multi level points/bboxes of each image. valid_flag_list (list[list]): Multi level valid flags of each image. gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image. img_metas (list[dict]): Meta info of each image. gt_bboxes_ignore_list (list[Tensor]): Ground truth bboxes to be ignored. gt_bboxes_list (list[Tensor]): Ground truth labels of each box. stage (str): `init` or `refine`. Generate target for init stage or refine stage label_channels (int): Channel of label. unmap_outputs (bool): Whether to map outputs back to the original set of anchors. Returns: tuple (list[Tensor]): - labels_list (list[Tensor]): Labels of each level. - label_weights_list (list[Tensor]): Label weights of each \ level. - bbox_gt_list (list[Tensor]): Ground truth bbox of each level. - proposal_list (list[Tensor]): Proposals(points/bboxes) of \ each level. - proposal_weights_list (list[Tensor]): Proposal weights of \ each level. - num_total_pos (int): Number of positive samples in all \ images. - num_total_neg (int): Number of negative samples in all \ images. """ assert stage in ['init', 'refine'] num_imgs = len(img_metas) assert len(proposals_list) == len(valid_flag_list) == num_imgs # points number of multi levels num_level_proposals = [points.size(0) for points in proposals_list[0]] # concat all level points and flags to a single tensor for i in range(num_imgs): assert len(proposals_list[i]) == len(valid_flag_list[i]) proposals_list[i] = torch.cat(proposals_list[i]) valid_flag_list[i] = torch.cat(valid_flag_list[i]) # compute targets for each image if gt_bboxes_ignore_list is None: gt_bboxes_ignore_list = [None for _ in range(num_imgs)] if gt_labels_list is None: gt_labels_list = [None for _ in range(num_imgs)] all_overlaps_rotate_list = [None] * 4 (all_labels, all_label_weights, all_bbox_gt, all_proposals, all_proposal_weights, pos_inds_list, neg_inds_list, all_gt_inds, sampling_result) = multi_apply( self._point_target_single, proposals_list, valid_flag_list, gt_bboxes_list, gt_bboxes_ignore_list, gt_labels_list, all_overlaps_rotate_list, stage=stage, unmap_outputs=unmap_outputs) if stage == 'init': # no valid points if any([labels is None for labels in all_labels]): return None # sampled points of all images num_total_pos = sum( [max(inds.numel(), 1) for inds in pos_inds_list]) num_total_neg = sum( [max(inds.numel(), 1) for inds in neg_inds_list]) labels_list = images_to_levels(all_labels, num_level_proposals) label_weights_list = images_to_levels(all_label_weights, num_level_proposals) bbox_gt_list = images_to_levels(all_bbox_gt, num_level_proposals) proposals_list = images_to_levels(all_proposals, num_level_proposals) proposal_weights_list = images_to_levels(all_proposal_weights, num_level_proposals) return (labels_list, label_weights_list, bbox_gt_list, proposals_list, proposal_weights_list, num_total_pos, num_total_neg, None) else: pos_inds = [] pos_gt_index = [] for i, single_labels in enumerate(all_labels): pos_mask = (0 <= single_labels) & ( single_labels < self.num_classes) pos_inds.append(pos_mask.nonzero(as_tuple=False).view(-1)) pos_gt_index.append( all_gt_inds[i][pos_mask.nonzero(as_tuple=False).view(-1)]) return (all_labels, all_label_weights, all_bbox_gt, all_proposals, all_proposal_weights, pos_inds, pos_gt_index)
[docs] def loss(self, cls_scores, pts_preds_init, pts_preds_refine, base_features, gt_bboxes, gt_labels, img_metas, gt_bboxes_ignore=None): """Loss function of OrientedRepPoints head.""" featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] assert len(featmap_sizes) == self.prior_generator.num_levels label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1 device = cls_scores[0].device center_list, valid_flag_list = self.get_points( featmap_sizes, img_metas, device=device) pts_coordinate_preds_init = self.offset_to_pts(center_list, pts_preds_init) num_proposals_each_level = [(featmap.size(-1) * featmap.size(-2)) for featmap in cls_scores] num_level = len(featmap_sizes) assert num_level == len(pts_coordinate_preds_init) candidate_list = center_list cls_reg_targets_init = self.get_targets( candidate_list, valid_flag_list, gt_bboxes, img_metas, gt_bboxes_ignore_list=gt_bboxes_ignore, gt_labels_list=gt_labels, stage='init', label_channels=label_channels) (*_, bbox_gt_list_init, candidate_list_init, bbox_weights_list_init, num_total_pos_init, num_total_neg_init, _) = cls_reg_targets_init center_list, valid_flag_list = self.get_points( featmap_sizes, img_metas, device=device) pts_coordinate_preds_refine = self.offset_to_pts( center_list, pts_preds_refine) refine_points_features, = multi_apply(self.get_adaptive_points_feature, base_features, pts_coordinate_preds_refine, self.point_strides) features_pts_refine = levels_to_images(refine_points_features) features_pts_refine = [ item.reshape(-1, self.num_points, item.shape[-1]) for item in features_pts_refine ] points_list = [] for i_img, center in enumerate(center_list): points = [] for i_lvl in range(len(pts_preds_refine)): points_preds_init_ = pts_preds_init[i_lvl].detach() points_preds_init_ = points_preds_init_.view( points_preds_init_.shape[0], -1, *points_preds_init_.shape[2:]) points_shift = points_preds_init_.permute( 0, 2, 3, 1) * self.point_strides[i_lvl] points_center = center[i_lvl][:, :2].repeat(1, self.num_points) points.append( points_center + points_shift[i_img].reshape(-1, 2 * self.num_points)) points_list.append(points) cls_reg_targets_refine = self.get_targets( points_list, valid_flag_list, gt_bboxes, img_metas, gt_bboxes_ignore_list=gt_bboxes_ignore, gt_labels_list=gt_labels, stage='refine', label_channels=label_channels) (labels_list, label_weights_list, bbox_gt_list_refine, _, bbox_weights_list_refine, pos_inds_list_refine, pos_gt_index_list_refine) = cls_reg_targets_refine cls_scores = levels_to_images(cls_scores) cls_scores = [ item.reshape(-1, self.cls_out_channels) for item in cls_scores ] pts_coordinate_preds_init_img = levels_to_images( pts_coordinate_preds_init, flatten=True) pts_coordinate_preds_init_img = [ item.reshape(-1, 2 * self.num_points) for item in pts_coordinate_preds_init_img ] pts_coordinate_preds_refine_img = levels_to_images( pts_coordinate_preds_refine, flatten=True) pts_coordinate_preds_refine_img = [ item.reshape(-1, 2 * self.num_points) for item in pts_coordinate_preds_refine_img ] with torch.no_grad(): quality_assess_list, = multi_apply( self.pointsets_quality_assessment, features_pts_refine, cls_scores, pts_coordinate_preds_init_img, pts_coordinate_preds_refine_img, labels_list, bbox_gt_list_refine, label_weights_list, bbox_weights_list_refine, pos_inds_list_refine) labels_list, label_weights_list, bbox_weights_list_refine, \ num_pos, pos_normalize_term = multi_apply( self.dynamic_pointset_samples_selection, quality_assess_list, labels_list, label_weights_list, bbox_weights_list_refine, pos_inds_list_refine, pos_gt_index_list_refine, num_proposals_each_level=num_proposals_each_level, num_level=num_level ) num_pos = sum(num_pos) # convert all tensor list to a flatten tensor cls_scores = torch.cat(cls_scores, 0).view(-1, cls_scores[0].size(-1)) pts_preds_refine = torch.cat(pts_coordinate_preds_refine_img, 0).view( -1, pts_coordinate_preds_refine_img[0].size(-1)) labels = torch.cat(labels_list, 0).view(-1) labels_weight = torch.cat(label_weights_list, 0).view(-1) bbox_gt_refine = torch.cat(bbox_gt_list_refine, 0).view(-1, bbox_gt_list_refine[0].size(-1)) bbox_weights_refine = torch.cat(bbox_weights_list_refine, 0).view(-1) pos_normalize_term = torch.cat(pos_normalize_term, 0).reshape(-1) pos_inds_flatten = ((0 <= labels) & (labels < self.num_classes)).nonzero( as_tuple=False).reshape(-1) assert len(pos_normalize_term) == len(pos_inds_flatten) if num_pos: losses_cls = self.loss_cls( cls_scores, labels, labels_weight, avg_factor=num_pos) pos_pts_pred_refine = pts_preds_refine[pos_inds_flatten] pos_bbox_gt_refine = bbox_gt_refine[pos_inds_flatten] pos_bbox_weights_refine = bbox_weights_refine[pos_inds_flatten] losses_pts_refine = self.loss_bbox_refine( pos_pts_pred_refine / pos_normalize_term.reshape(-1, 1), pos_bbox_gt_refine / pos_normalize_term.reshape(-1, 1), pos_bbox_weights_refine) loss_border_refine = self.loss_spatial_refine( pos_pts_pred_refine.reshape(-1, 2 * self.num_points) / pos_normalize_term.reshape(-1, 1), pos_bbox_gt_refine / pos_normalize_term.reshape(-1, 1), pos_bbox_weights_refine, avg_factor=None) else: losses_cls = cls_scores.sum() * 0 losses_pts_refine = pts_preds_refine.sum() * 0 loss_border_refine = pts_preds_refine.sum() * 0 losses_pts_init, loss_border_init = multi_apply( self.init_loss_single, pts_coordinate_preds_init, bbox_gt_list_init, bbox_weights_list_init, self.point_strides) loss_dict_all = { 'loss_cls': losses_cls, 'loss_pts_init': losses_pts_init, 'loss_pts_refine': losses_pts_refine, 'loss_spatial_init': loss_border_init, 'loss_spatial_refine': loss_border_refine } return loss_dict_all
[docs] @force_fp32(apply_to=('cls_scores', 'pts_preds_init', 'pts_preds_refine')) def get_bboxes(self, cls_scores, pts_preds_init, pts_preds_refine, base_feats, img_metas, cfg=None, rescale=False, with_nms=True, **kwargs): """Transform network outputs of a batch into bbox results. Args: cls_scores (list[Tensor]): Classification scores for all scale levels, each is a 4D-tensor, has shape (batch_size, num_priors * num_classes, H, W). pts_preds_init (list[Tensor]): Box energies / deltas for all scale levels, each is a 18D-tensor, has shape (batch_size, num_points * 2, H, W). pts_preds_refine (list[Tensor]): Box energies / deltas for all scale levels, each is a 18D-tensor, has shape (batch_size, num_points * 2, H, W). img_metas (list[dict], Optional): Image meta info. Default None. cfg (mmcv.Config, Optional): Test / postprocessing configuration, if None, test_cfg would be used. Default None. rescale (bool): If True, return boxes in original image space. Default False. with_nms (bool): If True, do nms before return boxes. Default True. Returns: list[list[Tensor, Tensor]]: Each item in result_list is 2-tuple. The first item is an (n, 6) tensor, where the first 4 columns are bounding box positions (cx, cy, w, h, a) and the 6-th column is a score between 0 and 1. The second item is a (n,) tensor where each item is the predicted class label of the corresponding box. """ assert len(cls_scores) == len(pts_preds_refine) num_levels = len(cls_scores) featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)] mlvl_priors = self.prior_generator.grid_priors( featmap_sizes, dtype=cls_scores[0].device, device=cls_scores[0].device) result_list = [] for img_id, _ in enumerate(img_metas): img_meta = img_metas[img_id] cls_score_list = select_single_mlvl(cls_scores, img_id) point_pred_list = select_single_mlvl(pts_preds_refine, img_id) results = self._get_bboxes_single(cls_score_list, point_pred_list, mlvl_priors, img_meta, cfg, rescale, with_nms, **kwargs) result_list.append(results) return result_list
def _get_bboxes_single(self, cls_score_list, point_pred_list, mlvl_priors, img_meta, cfg, rescale=False, with_nms=True, **kwargs): """Transform outputs of a single image into bbox predictions. Args: cls_score_list (list[Tensor]): Box scores from all scale levels of a single image, each item has shape (num_priors * num_classes, H, W). bbox_pred_list (list[Tensor]): Box energies / deltas from all scale levels of a single image, each item has shape (num_priors * 4, H, W). score_factor_list (list[Tensor]): Score factor from all scale levels of a single image. RepPoints head does not need this value. mlvl_priors (list[Tensor]): Each element in the list is the priors of a single level in feature pyramid, has shape (num_priors, 2). img_meta (dict): Image meta info. cfg (mmcv.Config): Test / postprocessing configuration, if None, test_cfg would be used. rescale (bool): If True, return boxes in original image space. Default: False. with_nms (bool): If True, do nms before return boxes. Default: True. Returns: tuple[Tensor]: Results of detected bboxes and labels. If with_nms is False and mlvl_score_factor is None, return mlvl_bboxes and mlvl_scores, else return mlvl_bboxes, mlvl_scores and mlvl_score_factor. Usually with_nms is False is used for aug test. If with_nms is True, then return the following format - det_bboxes (Tensor): Predicted bboxes with shape \ [num_bboxes, 5], where the first 4 columns are bounding \ box positions (cx, cy, w, h, a) and the 5-th \ column are scores between 0 and 1. - det_labels (Tensor): Predicted labels of the corresponding \ box with shape [num_bboxes]. """ cfg = self.test_cfg if cfg is None else cfg assert len(cls_score_list) == len(point_pred_list) scale_factor = img_meta['scale_factor'] mlvl_bboxes = [] mlvl_scores = [] for level_idx, (cls_score, points_pred, points) in enumerate( zip(cls_score_list, point_pred_list, mlvl_priors)): assert cls_score.size()[-2:] == points_pred.size()[-2:] cls_score = cls_score.permute(1, 2, 0).reshape(-1, self.cls_out_channels) if self.use_sigmoid_cls: scores = cls_score.sigmoid() else: scores = cls_score.softmax(-1)[:, :-1] points_pred = points_pred.permute(1, 2, 0).reshape( -1, 2 * self.num_points) nms_pre = cfg.get('nms_pre', -1) if 0 < nms_pre < scores.shape[0]: if self.use_sigmoid_cls: max_scores, _ = scores.max(dim=1) else: max_scores, _ = scores[:, 1:].max(dim=1) _, topk_inds = max_scores.topk(nms_pre) points = points[topk_inds, :] points_pred = points_pred[topk_inds, :] scores = scores[topk_inds, :] pts_pred = points_pred.reshape(-1, self.num_points, 2) pts_pred_offsety = pts_pred[:, :, 0::2] pts_pred_offsetx = pts_pred[:, :, 1::2] pts_pred = torch.cat([pts_pred_offsetx, pts_pred_offsety], dim=2).reshape(-1, 2 * self.num_points) pts_pos_center = points[:, :2].repeat(1, self.num_points) pts = pts_pred * self.point_strides[level_idx] + pts_pos_center polys = min_area_polygons(pts) bboxes = poly2obb(polys, self.version) mlvl_bboxes.append(bboxes) mlvl_scores.append(scores) mlvl_bboxes = torch.cat(mlvl_bboxes) if rescale: mlvl_bboxes[..., :4] /= mlvl_bboxes[..., :4].new_tensor( scale_factor) mlvl_scores = torch.cat(mlvl_scores) if self.use_sigmoid_cls: padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1) mlvl_scores = torch.cat([mlvl_scores, padding], dim=1) if with_nms: det_bboxes, det_labels = multiclass_nms_rotated( mlvl_bboxes, mlvl_scores, cfg.score_thr, cfg.nms, cfg.max_per_img) return det_bboxes, det_labels else: raise NotImplementedError
Read the Docs v: v0.3.2
Versions
latest
stable
v0.3.2
v0.3.1
v0.3.0
v0.2.0
v0.1.1
v0.1.0
main
dev
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.