Shortcuts

Source code for mmrotate.models.roi_heads.bbox_heads.rotated_bbox_head

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.runner import BaseModule, auto_fp16, force_fp32
from mmcv.utils import to_2tuple
from mmdet.core import multi_apply
from mmdet.models.losses import accuracy
from mmdet.models.utils import build_linear_layer

from mmrotate.core import build_bbox_coder, multiclass_nms_rotated
from ...builder import ROTATED_HEADS, build_loss


@ROTATED_HEADS.register_module()
class RotatedBBoxHead(BaseModule):
    """Simplest RoI head, with only two fc layers for classification and
    regression respectively.

    Boxes are rotated boxes with 5 parameters ``(cx, cy, w, h, a)``; RoIs
    additionally carry the image index in their first column, i.e.
    ``(batch_index, cx, cy, w, h, a)``.

    Args:
        with_avg_pool (bool, optional): If True, use ``avg_pool``.
        with_cls (bool, optional): If True, use classification branch.
        with_reg (bool, optional): If True, use regression branch.
        roi_feat_size (int, optional): Size of RoI features.
        in_channels (int, optional): Input channels.
        num_classes (int, optional): Number of classes.
        bbox_coder (dict, optional): Config of bbox coder.
        reg_class_agnostic (bool, optional): If True, regression branch are
            class agnostic.
        reg_decoded_bbox (bool, optional): If True, regression branch use
            decoded bbox to compute loss.
        reg_predictor_cfg (dict, optional): Config of regression predictor.
        cls_predictor_cfg (dict, optional): Config of classification
            predictor.
        loss_cls (dict, optional): Config of classification loss.
        loss_bbox (dict, optional): Config of regression loss.
        init_cfg (dict, optional): Config of initialization.
    """

    def __init__(self,
                 with_avg_pool=False,
                 with_cls=True,
                 with_reg=True,
                 roi_feat_size=7,
                 in_channels=256,
                 num_classes=80,
                 # NOTE(review): this default describes a 4-parameter
                 # coder while the regression branch below predicts 5
                 # values per box; rotated configs presumably override it
                 # with a 5-parameter coder -- confirm.
                 bbox_coder=dict(
                     type='DeltaXYWHBBoxCoder',
                     clip_border=True,
                     target_means=[0., 0., 0., 0.],
                     target_stds=[0.1, 0.1, 0.2, 0.2]),
                 reg_class_agnostic=False,
                 reg_decoded_bbox=False,
                 reg_predictor_cfg=dict(type='Linear'),
                 cls_predictor_cfg=dict(type='Linear'),
                 loss_cls=dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=False,
                     loss_weight=1.0),
                 loss_bbox=dict(
                     type='SmoothL1Loss', beta=1.0, loss_weight=1.0),
                 init_cfg=None):
        super(RotatedBBoxHead, self).__init__(init_cfg)
        assert with_cls or with_reg
        self.with_avg_pool = with_avg_pool
        self.with_cls = with_cls
        self.with_reg = with_reg
        self.roi_feat_size = to_2tuple(roi_feat_size)
        self.roi_feat_area = self.roi_feat_size[0] * self.roi_feat_size[1]
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.reg_class_agnostic = reg_class_agnostic
        self.reg_decoded_bbox = reg_decoded_bbox
        self.reg_predictor_cfg = reg_predictor_cfg
        self.cls_predictor_cfg = cls_predictor_cfg
        self.fp16_enabled = False

        self.bbox_coder = build_bbox_coder(bbox_coder)
        self.loss_cls = build_loss(loss_cls)
        self.loss_bbox = build_loss(loss_bbox)

        in_channels = self.in_channels
        if self.with_avg_pool:
            self.avg_pool = nn.AvgPool2d(self.roi_feat_size)
        else:
            # Without pooling the whole RoI feature map is flattened.
            in_channels *= self.roi_feat_area
        if self.with_cls:
            # need to add background class
            if self.custom_cls_channels:
                cls_channels = self.loss_cls.get_cls_channels(
                    self.num_classes)
            else:
                cls_channels = num_classes + 1
            self.fc_cls = build_linear_layer(
                self.cls_predictor_cfg,
                in_features=in_channels,
                out_features=cls_channels)
        if self.with_reg:
            # 5 regression values (cx, cy, w, h, a) per box, optionally
            # per class.
            out_dim_reg = 5 if reg_class_agnostic else 5 * num_classes
            self.fc_reg = build_linear_layer(
                self.reg_predictor_cfg,
                in_features=in_channels,
                out_features=out_dim_reg)
        self.debug_imgs = None
        if init_cfg is None:
            self.init_cfg = []
            if self.with_cls:
                self.init_cfg += [
                    dict(
                        type='Normal', std=0.01, override=dict(name='fc_cls'))
                ]
            if self.with_reg:
                self.init_cfg += [
                    dict(
                        type='Normal', std=0.001, override=dict(name='fc_reg'))
                ]

    @property
    def custom_cls_channels(self):
        """The custom cls channels."""
        return getattr(self.loss_cls, 'custom_cls_channels', False)

    @property
    def custom_activation(self):
        """The custom activation."""
        return getattr(self.loss_cls, 'custom_activation', False)

    @property
    def custom_accuracy(self):
        """The custom accuracy."""
        return getattr(self.loss_cls, 'custom_accuracy', False)

    @auto_fp16()
    def forward(self, x):
        """Forward function of Rotated BBoxHead.

        Args:
            x (torch.Tensor): RoI features. After the optional average
                pooling they are flattened from dim 1 onwards and fed to the
                fc layers.

        Returns:
            tuple: ``(cls_score, bbox_pred)``; an entry is None when the
            corresponding branch is disabled.
        """
        if self.with_avg_pool:
            x = self.avg_pool(x)
        # reshape behaves like view for contiguous input but also accepts
        # non-contiguous feature tensors.
        x = x.reshape(x.size(0), -1)
        cls_score = self.fc_cls(x) if self.with_cls else None
        bbox_pred = self.fc_reg(x) if self.with_reg else None
        return cls_score, bbox_pred

    def _get_target_single(self, pos_bboxes, neg_bboxes, pos_gt_bboxes,
                           pos_gt_labels, cfg):
        """Calculate the ground truth for proposals in the single image
        according to the sampling results.

        Args:
            pos_bboxes (torch.Tensor): Contains all the positive boxes,
                has shape (num_pos, 5), the last dimension 5
                represents [cx, cy, w, h, a].
            neg_bboxes (torch.Tensor): Contains all the negative boxes,
                has shape (num_neg, 5), the last dimension 5
                represents [cx, cy, w, h, a].
            pos_gt_bboxes (torch.Tensor): The gt box matched to each
                positive box, has shape (num_pos, 5), the last dimension 5
                represents [cx, cy, w, h, a].
            pos_gt_labels (torch.Tensor): The gt label matched to each
                positive box, has shape (num_pos, ).
            cfg (obj:`ConfigDict`): `train_cfg` of R-CNN.

        Returns:
            Tuple[Tensor]: Ground truth for proposals in a single image.
            Containing the following Tensors:

                - labels(torch.Tensor): Gt_labels for all proposals, has
                  shape (num_proposals,). Negatives get ``num_classes``.
                - label_weights(torch.Tensor): Labels_weights for all
                  proposals, has shape (num_proposals,).
                - bbox_targets(torch.Tensor): Regression target for all
                  proposals, has shape (num_proposals, 5), the
                  last dimension 5 represents [cx, cy, w, h, a].
                - bbox_weights(torch.Tensor): Regression weights for all
                  proposals, has shape (num_proposals, 5).
        """
        num_pos = pos_bboxes.size(0)
        num_neg = neg_bboxes.size(0)
        num_samples = num_pos + num_neg

        # original implementation uses new_zeros since BG are set to be 0
        # now use empty & fill because BG cat_id = num_classes,
        # FG cat_id = [0, num_classes-1]
        labels = pos_bboxes.new_full((num_samples, ),
                                     self.num_classes,
                                     dtype=torch.long)
        label_weights = pos_bboxes.new_zeros(num_samples)
        bbox_targets = pos_bboxes.new_zeros(num_samples, 5)
        bbox_weights = pos_bboxes.new_zeros(num_samples, 5)
        if num_pos > 0:
            labels[:num_pos] = pos_gt_labels
            pos_weight = 1.0 if cfg.pos_weight <= 0 else cfg.pos_weight
            label_weights[:num_pos] = pos_weight
            if not self.reg_decoded_bbox:
                pos_bbox_targets = self.bbox_coder.encode(
                    pos_bboxes, pos_gt_bboxes)
            else:
                # When the regression loss (e.g. `IouLoss`, `GIouLoss`)
                # is applied directly on the decoded bounding boxes, both
                # the predicted boxes and regression targets should be with
                # absolute coordinate format.
                pos_bbox_targets = pos_gt_bboxes
            bbox_targets[:num_pos, :] = pos_bbox_targets
            bbox_weights[:num_pos, :] = 1
        if num_neg > 0:
            label_weights[-num_neg:] = 1.0

        return labels, label_weights, bbox_targets, bbox_weights

    def get_targets(self,
                    sampling_results,
                    gt_bboxes,
                    gt_labels,
                    rcnn_train_cfg,
                    concat=True):
        """Calculate the ground truth for all samples in a batch according to
        the sampling_results.

        Almost the same as the implementation in bbox_head, we passed
        additional parameters pos_inds_list and neg_inds_list to
        `_get_target_single` function.

        Args:
            sampling_results (List[obj:SamplingResults]): Assign results of
                all images in a batch after sampling.
            gt_bboxes (list[Tensor]): Gt_bboxes of all images in a batch,
                each tensor has shape (num_gt, 5),  the last dimension 5
                represents [cx, cy, w, h, a]. Not used directly; the
                matched gts are read from ``sampling_results``.
            gt_labels (list[Tensor]): Gt_labels of all images in a batch,
                each tensor has shape (num_gt,). Not used directly.
            rcnn_train_cfg (obj:ConfigDict): `train_cfg` of RCNN.
            concat (bool): Whether to concatenate the results of all
                the images in a single batch.

        Returns:
            Tuple[Tensor]: Ground truth for proposals in a batch.
            Containing the following list of Tensors:

                - labels (list[Tensor],Tensor): Gt_labels for all
                  proposals in a batch, each tensor in list has
                  shape (num_proposals,) when `concat=False`, otherwise
                  just a single tensor has shape (num_all_proposals,).
                - label_weights (list[Tensor]): Labels_weights for
                  all proposals in a batch, each tensor in list has
                  shape (num_proposals,) when `concat=False`, otherwise
                  just a single tensor has shape (num_all_proposals,).
                - bbox_targets (list[Tensor],Tensor): Regression target
                  for all proposals in a batch, each tensor in list
                  has shape (num_proposals, 5) when `concat=False`,
                  otherwise just a single tensor has shape
                  (num_all_proposals, 5), the last dimension 5 represents
                  [cx, cy, w, h, a].
                - bbox_weights (list[tensor],Tensor): Regression weights for
                  all proposals in a batch, each tensor in list has shape
                  (num_proposals, 5) when `concat=False`, otherwise just a
                  single tensor has shape (num_all_proposals, 5).
        """
        pos_bboxes_list = [res.pos_bboxes for res in sampling_results]
        neg_bboxes_list = [res.neg_bboxes for res in sampling_results]
        pos_gt_bboxes_list = [res.pos_gt_bboxes for res in sampling_results]
        pos_gt_labels_list = [res.pos_gt_labels for res in sampling_results]
        labels, label_weights, bbox_targets, bbox_weights = multi_apply(
            self._get_target_single,
            pos_bboxes_list,
            neg_bboxes_list,
            pos_gt_bboxes_list,
            pos_gt_labels_list,
            cfg=rcnn_train_cfg)

        if concat:
            labels = torch.cat(labels, 0)
            label_weights = torch.cat(label_weights, 0)
            bbox_targets = torch.cat(bbox_targets, 0)
            bbox_weights = torch.cat(bbox_weights, 0)
        return labels, label_weights, bbox_targets, bbox_weights

    @force_fp32(apply_to=('cls_score', 'bbox_pred'))
    def loss(self,
             cls_score,
             bbox_pred,
             rois,
             labels,
             label_weights,
             bbox_targets,
             bbox_weights,
             reduction_override=None):
        """Loss function.

        Args:
            cls_score (torch.Tensor): Box scores, has shape
                (num_boxes, num_classes + 1).
            bbox_pred (Tensor, optional): Box energies / deltas.
                has shape (num_boxes, num_classes * 5) or (num_boxes, 5)
                when regression is class agnostic.
            rois (torch.Tensor): Boxes to be transformed. Has shape
                (num_boxes, 6); ``rois[:, 1:]`` are the (cx, cy, w, h, a)
                boxes and the first column is the batch index.
            labels (torch.Tensor): Shape (n*bs, ).
            label_weights(torch.Tensor): Labels_weights for all proposals,
                has shape (num_proposals,).
            bbox_targets(torch.Tensor): Regression target for all proposals,
                has shape (num_proposals, 5), the last dimension 5
                represents [cx, cy, w, h, a].
            bbox_weights (torch.Tensor): Regression weights for all
                proposals, has shape (num_proposals, 5).
            reduction_override (str, optional): The reduction
                method used to override the original reduction
                method of the loss. Defaults to None.

        Returns:
            dict: Classification loss/accuracy and/or regression loss,
            depending on which predictions are given.
        """
        losses = dict()
        if cls_score is not None:
            avg_factor = max(torch.sum(label_weights > 0).float().item(), 1.)
            if cls_score.numel() > 0:
                loss_cls_ = self.loss_cls(
                    cls_score,
                    labels,
                    label_weights,
                    avg_factor=avg_factor,
                    reduction_override=reduction_override)
                if isinstance(loss_cls_, dict):
                    losses.update(loss_cls_)
                else:
                    losses['loss_cls'] = loss_cls_
                if self.custom_activation:
                    acc_ = self.loss_cls.get_accuracy(cls_score, labels)
                    losses.update(acc_)
                else:
                    losses['acc'] = accuracy(cls_score, labels)
        if bbox_pred is not None:
            bg_class_ind = self.num_classes
            # 0~self.num_classes-1 are FG, self.num_classes is BG
            pos_inds = (labels >= 0) & (labels < bg_class_ind)
            # do not perform bounding box regression for BG anymore.
            if pos_inds.any():
                if self.reg_decoded_bbox:
                    # When the regression loss (e.g. `IouLoss`,
                    # `GIouLoss`, `DIouLoss`) is applied directly on
                    # the decoded bounding boxes, it decodes the
                    # already encoded coordinates to absolute format.
                    bbox_pred = self.bbox_coder.decode(rois[:, 1:], bbox_pred)
                pos_mask = pos_inds.type(torch.bool)
                if self.reg_class_agnostic:
                    pos_bbox_pred = bbox_pred.view(
                        bbox_pred.size(0), 5)[pos_mask]
                else:
                    # Pick, for every positive RoI, the 5 deltas predicted
                    # for its own gt class.
                    pos_bbox_pred = bbox_pred.view(
                        bbox_pred.size(0), -1, 5)[pos_mask, labels[pos_mask]]
                losses['loss_bbox'] = self.loss_bbox(
                    pos_bbox_pred,
                    bbox_targets[pos_mask],
                    bbox_weights[pos_mask],
                    avg_factor=bbox_targets.size(0),
                    reduction_override=reduction_override)
            else:
                # No positives: a zero-valued loss that still depends on
                # bbox_pred.
                losses['loss_bbox'] = bbox_pred[pos_inds].sum()
        return losses

    @force_fp32(apply_to=('cls_score', 'bbox_pred'))
    def get_bboxes(self,
                   rois,
                   cls_score,
                   bbox_pred,
                   img_shape,
                   scale_factor,
                   rescale=False,
                   cfg=None):
        """Transform network output for a batch into bbox predictions.

        Args:
            rois (torch.Tensor): Boxes to be transformed. Has shape
                (num_boxes, 6); ``rois[:, 1:]`` are the (cx, cy, w, h, a)
                boxes and the first column is the batch index.
            cls_score (torch.Tensor): Box scores, has shape
                (num_boxes, num_classes + 1).
            bbox_pred (Tensor, optional): Box energies / deltas.
                has shape (num_boxes, num_classes * 5).
            img_shape (Sequence[int], optional): Maximum bounds for boxes,
                specifies (H, W, C) or (H, W).
            scale_factor (ndarray): Scale factor of the
               image arrange as (w_scale, h_scale, w_scale, h_scale). It is
               applied to (cx, cy, w, h); the angle is left unchanged.
            rescale (bool): If True, return boxes in original image space.
                Default: False.
            cfg (obj:`ConfigDict`): `test_cfg` of Bbox Head. Default: None

        Returns:
            tuple[Tensor, Tensor]:
                If ``cfg`` is None, the decoded boxes and the class scores.
                Otherwise, first tensor is `det_bboxes`, has the shape
                (num_boxes, 6) and last dimension 6 represent
                (cx, cy, w, h, a, score). Second tensor is the
                labels with shape (num_boxes, ).
        """
        # some loss (Seesaw loss..) may have custom activation
        if self.custom_cls_channels:
            scores = self.loss_cls.get_activation(cls_score)
        else:
            scores = F.softmax(
                cls_score, dim=-1) if cls_score is not None else None
        # bbox_pred would be None in some detector when with_reg is False,
        # e.g. Grid R-CNN.
        if bbox_pred is not None:
            bboxes = self.bbox_coder.decode(
                rois[..., 1:], bbox_pred, max_shape=img_shape)
        else:
            bboxes = rois[:, 1:].clone()
            if img_shape is not None:
                # For (cx, cy, w, h, a) boxes only the centers are image
                # positions, so only they are clamped. Basic indexing is
                # required: advanced indexing such as ``bboxes[:, [0, 2]]``
                # returns a copy, and an in-place clamp on it is lost.
                bboxes[:, 0].clamp_(min=0, max=img_shape[1])
                bboxes[:, 1].clamp_(min=0, max=img_shape[0])

        if rescale and bboxes.size(0) > 0:
            scale_factor = bboxes.new_tensor(scale_factor)
            bboxes = bboxes.view(bboxes.size(0), -1, 5)
            bboxes[..., :4] = bboxes[..., :4] / scale_factor
            bboxes = bboxes.view(bboxes.size(0), -1)

        if cfg is None:
            return bboxes, scores
        det_bboxes, det_labels = multiclass_nms_rotated(
            bboxes, scores, cfg.score_thr, cfg.nms, cfg.max_per_img)
        return det_bboxes, det_labels

    @force_fp32(apply_to=('bbox_preds', ))
    def refine_bboxes(self, rois, labels, bbox_preds, pos_is_gts, img_metas):
        """Refine bboxes during training.

        Args:
            rois (torch.Tensor): Shape (n*bs, 6), where n is image number per
                GPU, and bs is the sampled RoIs per image. The first column is
                the image id and the next 5 columns are (cx, cy, w, h, a).
            labels (torch.Tensor): Shape (n*bs, ).
            bbox_preds (torch.Tensor): Shape (n*bs, 5) or (n*bs, 5*#class).
            pos_is_gts (list[Tensor]): Flags indicating if each positive bbox
                is a gt bbox.
            img_metas (list[dict]): Meta info of each image.

        Returns:
            list[Tensor]: Refined bboxes of each image in a mini-batch.
        """
        img_ids = rois[:, 0].long().unique(sorted=True)
        assert img_ids.numel() <= len(img_metas)

        bboxes_list = []
        for i, img_meta_ in enumerate(img_metas):
            inds = torch.nonzero(
                rois[:, 0] == i, as_tuple=False).squeeze(dim=1)
            num_rois = inds.numel()

            bboxes_ = rois[inds, 1:]
            label_ = labels[inds]
            bbox_pred_ = bbox_preds[inds]
            pos_is_gts_ = pos_is_gts[i]

            bboxes = self.regress_by_class(bboxes_, label_, bbox_pred_,
                                           img_meta_)

            # filter gt bboxes: the leading positives that are gts are
            # dropped, everything else is kept. Working on a bool mask
            # (instead of ``1 - flags``) also supports bool flag tensors.
            keep_inds = torch.ones(
                num_rois, dtype=torch.bool, device=pos_is_gts_.device)
            keep_inds[:len(pos_is_gts_)] = ~pos_is_gts_.bool()
            bboxes_list.append(bboxes[keep_inds])

        return bboxes_list

    @force_fp32(apply_to=('bbox_pred', ))
    def regress_by_class(self, rois, label, bbox_pred, img_meta):
        """Regress the bbox for the predicted class. Used in Cascade R-CNN.

        Args:
            rois (torch.Tensor): shape (n, 5) holding (cx, cy, w, h, a), or
                (n, 6) with an extra leading batch-index column.
            label (torch.Tensor): shape (n, )
            bbox_pred (torch.Tensor): shape (n, 5*(#class)) or (n, 5)
            img_meta (dict): Image meta info.

        Returns:
            Tensor: Regressed bboxes, the same shape as input rois.
        """
        assert rois.size(1) == 5 or rois.size(1) == 6, repr(rois.shape)

        if not self.reg_class_agnostic:
            # Gather the 5 deltas belonging to each RoI's class.
            label = label * 5
            inds = torch.stack(
                (label, label + 1, label + 2, label + 3, label + 4), 1)
            bbox_pred = torch.gather(bbox_pred, 1, inds)
        assert bbox_pred.size(1) == 5

        if rois.size(1) == 5:
            new_rois = self.bbox_coder.decode(
                rois, bbox_pred, max_shape=img_meta['img_shape'])
        else:
            # Keep the batch-index column and decode only the box part.
            bboxes = self.bbox_coder.decode(
                rois[:, 1:], bbox_pred, max_shape=img_meta['img_shape'])
            new_rois = torch.cat((rois[:, [0]], bboxes), dim=1)

        return new_rois
Read the Docs v: v0.2.0
Versions
latest
stable
v0.2.0
v0.1.1
v0.1.0
main
dev
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.