
# Source code for mmrotate.models.dense_heads.odm_refine_head

# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmcv.runner import force_fp32

from ..builder import ROTATED_HEADS
from ..utils import ORConv2d, RotationInvariantPooling
from .rotated_retina_head import RotatedRetinaHead

@ROTATED_HEADS.register_module()
class ODMRefineHead(RotatedRetinaHead):
    """Rotated Anchor-based refine head.

    It's a part of the Oriented Detection Module (ODM), which produces
    orientation-invariant features for classification and
    orientation-sensitive features for localization (see
    ``forward_single``: the cls branch consumes the rotation-invariant
    pooled features, the reg branch consumes the raw ORConv features).

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        in_channels (int): Number of channels in the input feature map.
        stacked_convs (int, optional): Number of stacked convolutions in
            each of the cls and reg branches. Default: 2.
        conv_cfg (dict, optional): Config dict for convolution layer.
            Default: None.
        norm_cfg (dict, optional): Config dict for normalization layer.
            Default: None.
        anchor_generator (dict): Config dict for anchor generator.
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """  # noqa: W605

    def __init__(self,
                 num_classes,
                 in_channels,
                 stacked_convs=2,
                 conv_cfg=None,
                 norm_cfg=None,
                 anchor_generator=dict(
                     type='PseudoAnchorGenerator',
                     strides=[8, 16, 32, 64, 128]),
                 init_cfg=dict(
                     type='Normal',
                     layer='Conv2d',
                     std=0.01,
                     override=dict(
                         type='Normal',
                         name='odm_cls',
                         std=0.01,
                         bias_prob=0.01)),
                 **kwargs):
        # Boxes from the previous stage; set by ``loss`` and read by
        # ``get_anchors`` in place of generated anchors.
        self.bboxes_as_anchors = None
        self.stacked_convs = stacked_convs
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        # Forward the caller's values instead of a hard-coded
        # ``stacked_convs=2``: a RetinaHead-style parent presumably stores
        # stacked_convs/conv_cfg/norm_cfg on self itself before building
        # layers, which would otherwise discard the values set above.
        super(ODMRefineHead, self).__init__(
            num_classes,
            in_channels,
            stacked_convs=stacked_convs,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            anchor_generator=anchor_generator,
            init_cfg=init_cfg,
            **kwargs)

    def _init_layers(self):
        """Initialize layers of the head."""
        # NOTE(review): judging from the channel counts used below, or_conv
        # is assumed to emit feat_channels channels ((feat_channels // 8)
        # per orientation x 8 orientations) and or_pool to reduce that by
        # 8 -- confirm against ORConv2d / RotationInvariantPooling.
        self.or_conv = ORConv2d(
            self.feat_channels,
            self.feat_channels // 8,
            kernel_size=3,
            padding=1,
            arf_config=(1, 8))
        self.or_pool = RotationInvariantPooling(self.feat_channels, 8)
        self.cls_convs = nn.ModuleList()
        self.reg_convs = nn.ModuleList()
        # The cls branch starts from the pooled features (1/8 of the
        # channels); every later layer sees feat_channels.
        cls_in_channels = self.feat_channels // 8
        for _ in range(self.stacked_convs):
            self.reg_convs.append(
                ConvModule(
                    self.feat_channels,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg))
            self.cls_convs.append(
                ConvModule(
                    cls_in_channels,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg))
            cls_in_channels = self.feat_channels
        # Use the width that actually reaches odm_cls, so stacked_convs=0
        # (pooled features fed straight to the classifier) is consistent.
        self.odm_cls = nn.Conv2d(
            cls_in_channels,
            self.num_anchors * self.cls_out_channels,
            3,
            padding=1)
        self.odm_reg = nn.Conv2d(
            self.feat_channels, self.num_anchors * 5, 3, padding=1)

    def forward_single(self, x):
        """Forward feature of a single scale level.

        Args:
            x (torch.Tensor): Features of a single scale level.

        Returns:
            tuple (torch.Tensor):

                - cls_score (torch.Tensor): Cls scores for a single scale
                  level; the channels number is
                  num_anchors * cls_out_channels.
                - bbox_pred (torch.Tensor): Box energies / deltas for a
                  single scale level; the channels number is
                  num_anchors * 5.
        """
        or_feat = self.or_conv(x)
        # Localization uses the orientation-sensitive ORConv output ...
        reg_feat = or_feat
        # ... while classification uses the rotation-invariant pooled one.
        cls_feat = self.or_pool(or_feat)
        for cls_conv in self.cls_convs:
            cls_feat = cls_conv(cls_feat)
        for reg_conv in self.reg_convs:
            reg_feat = reg_conv(reg_feat)
        cls_score = self.odm_cls(cls_feat)
        bbox_pred = self.odm_reg(reg_feat)
        return cls_score, bbox_pred

    def get_anchors(self, featmap_sizes, img_metas, device='cuda'):
        """Get anchors according to feature map sizes.

        Rather than generating anchors, this returns copies of
        ``self.bboxes_as_anchors`` (list[list[Tensor]]: per image, per
        level boxes from the previous stage, set by ``loss``), which are
        treated as anchors for further regression.

        Args:
            featmap_sizes (list[tuple]): Multi-level feature map sizes.
            img_metas (list[dict]): Image meta info.
            device (torch.device | str): Device for returned tensors.

        Returns:
            tuple (list[Tensor]):

                - anchor_list (list[Tensor]): Anchors of each image
                - valid_flag_list (list[Tensor]): Valid flags of each image
        """
        # Detached copies: no gradient flows back into the boxes that
        # produced these anchors.
        anchor_list = [[
            bboxes_img_lvl.clone().detach() for bboxes_img_lvl in bboxes_img
        ] for bboxes_img in self.bboxes_as_anchors]

        # For each image, compute valid flags of multi-level anchors.
        valid_flag_list = [
            self.anchor_generator.valid_flags(featmap_sizes,
                                              img_meta['pad_shape'], device)
            for img_meta in img_metas
        ]
        return anchor_list, valid_flag_list

    @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
    def loss(self,
             cls_scores,
             bbox_preds,
             gt_bboxes,
             gt_labels,
             img_metas,
             rois=None,
             gt_bboxes_ignore=None):
        """Loss function of ODMRefineHead.

        ``rois`` (list[list[Tensor]]) are stored as ``bboxes_as_anchors`` so
        that the parent's loss, via ``get_anchors``, uses them as anchors.
        The remaining arguments are forwarded unchanged to the parent's
        ``loss``.
        """
        assert rois is not None
        self.bboxes_as_anchors = rois
        return super(ODMRefineHead, self).loss(
            cls_scores=cls_scores,
            bbox_preds=bbox_preds,
            gt_bboxes=gt_bboxes,
            gt_labels=gt_labels,
            img_metas=img_metas,
            gt_bboxes_ignore=gt_bboxes_ignore)

    @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
    def get_bboxes(self,
                   cls_scores,
                   bbox_preds,
                   img_metas,
                   cfg=None,
                   rescale=False,
                   rois=None):
        """Transform network output for a batch into labeled boxes.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level,
                with shape (N, num_anchors * num_classes, H, W).
            bbox_preds (list[Tensor]): Box energies / deltas for each scale
                level with shape (N, num_anchors * 5, H, W).
            img_metas (list[dict]): Size / scale info for each image.
            cfg (mmcv.Config): Test / postprocessing configuration.
            rescale (bool): If True, return boxes in original image space.
            rois (list[list[Tensor]]): Input rbboxes of each level of each
                image, output by former stages and to be refined here.

        Returns:
            list[tuple[Tensor, Tensor]]: Each item in result_list is a
                2-tuple. The first item is an (n, 6) tensor, where the
                first 5 columns are bounding box positions
                (xc, yc, w, h, a) and the 6-th column is a score between 0
                and 1. The second item is a (n,) tensor where each item is
                the class index of the corresponding box.
        """
        num_levels = len(cls_scores)
        assert len(cls_scores) == len(bbox_preds)
        assert rois is not None
        result_list = []
        for img_id, img_meta in enumerate(img_metas):
            cls_score_list = [
                cls_scores[i][img_id].detach() for i in range(num_levels)
            ]
            bbox_pred_list = [
                bbox_preds[i][img_id].detach() for i in range(num_levels)
            ]
            proposals = self._get_bboxes_single(cls_score_list,
                                                bbox_pred_list, rois[img_id],
                                                img_meta['img_shape'],
                                                img_meta['scale_factor'], cfg,
                                                rescale)
            result_list.append(proposals)
        return result_list
# Read the Docs v: stable
# On Read the Docs
# Project Home

# Free document hosting provided by Read the Docs.