Source code for mmrotate.models.dense_heads.rotated_retina_head

# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmcv.runner import force_fp32

from ..builder import ROTATED_HEADS
from .rotated_anchor_head import RotatedAnchorHead


@ROTATED_HEADS.register_module()
class RotatedRetinaHead(RotatedAnchorHead):
    r"""An anchor-based head used in `RotatedRetinaNet
    <https://arxiv.org/pdf/1708.02002.pdf>`_.

    The head contains two subnetworks. The first classifies anchor boxes and
    the second regresses deltas for the anchors.

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        in_channels (int): Number of channels in the input feature map.
        stacked_convs (int, optional): Number of stacked convolutions.
        conv_cfg (dict, optional): Config dict for convolution layer.
            Default: None.
        norm_cfg (dict, optional): Config dict for normalization layer.
            Default: None.
        anchor_generator (dict): Config dict for anchor generator.
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """  # noqa: W605

    def __init__(self,
                 num_classes,
                 in_channels,
                 stacked_convs=4,
                 conv_cfg=None,
                 norm_cfg=None,
                 anchor_generator=dict(
                     type='AnchorGenerator',
                     octave_base_scale=4,
                     scales_per_octave=3,
                     ratios=[0.5, 1.0, 2.0],
                     strides=[8, 16, 32, 64, 128]),
                 init_cfg=dict(
                     type='Normal',
                     layer='Conv2d',
                     std=0.01,
                     override=dict(
                         type='Normal',
                         name='retina_cls',
                         std=0.01,
                         bias_prob=0.01)),
                 **kwargs):
        self.stacked_convs = stacked_convs
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        super(RotatedRetinaHead, self).__init__(
            num_classes,
            in_channels,
            anchor_generator=anchor_generator,
            init_cfg=init_cfg,
            **kwargs)

    def _init_layers(self):
        """Initialize layers of the head."""
        self.relu = nn.ReLU(inplace=True)
        self.cls_convs = nn.ModuleList()
        self.reg_convs = nn.ModuleList()
        for i in range(self.stacked_convs):
            chn = self.in_channels if i == 0 else self.feat_channels
            self.cls_convs.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg))
            self.reg_convs.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg))
        self.retina_cls = nn.Conv2d(
            self.feat_channels,
            self.num_anchors * self.cls_out_channels,
            3,
            padding=1)
        self.retina_reg = nn.Conv2d(
            self.feat_channels, self.num_anchors * 5, 3, padding=1)
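
A minimal construction sketch (illustrative, not part of the module). It swaps
in RotatedAnchorGenerator, as the R3Det and S2ANet configs do, so that the
grid_priors calls in the methods below yield 5-dim rotated anchors; every
other argument keeps the defaults from __init__:

from mmrotate.models.dense_heads import RotatedRetinaHead

head = RotatedRetinaHead(
    num_classes=15,   # e.g. the 15 DOTA categories
    in_channels=256,  # FPN output channels
    anchor_generator=dict(
        type='RotatedAnchorGenerator',
        octave_base_scale=4,
        scales_per_octave=3,
        ratios=[0.5, 1.0, 2.0],
        strides=[8, 16, 32, 64, 128]))

# 3 octave scales x 3 ratios = 9 anchors per location; with the default
# sigmoid focal loss, cls_out_channels == num_classes.
print(head.retina_cls.out_channels)  # 9 * 15 = 135
print(head.retina_reg.out_channels)  # 9 * 5 = 45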

    def forward_single(self, x):
        """Forward feature of a single scale level.

        Args:
            x (torch.Tensor): Features of a single scale level.

        Returns:
            tuple (torch.Tensor):

                - cls_score (torch.Tensor): Cls scores for a single scale \
                    level, the channels number is num_anchors * num_classes.
                - bbox_pred (torch.Tensor): Box energies / deltas for a \
                    single scale level, the channels number is
                    num_anchors * 5.
        """
        cls_feat = x
        reg_feat = x
        for cls_conv in self.cls_convs:
            cls_feat = cls_conv(cls_feat)
        for reg_conv in self.reg_convs:
            reg_feat = reg_conv(reg_feat)
        cls_score = self.retina_cls(cls_feat)
        bbox_pred = self.retina_reg(reg_feat)
        return cls_score, bbox_pred
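
A quick shape check for forward_single, continuing the sketch above (the
feature size assumes a stride-8 level of a hypothetical 512x512 input):

import torch

x = torch.rand(2, 256, 64, 64)
cls_score, bbox_pred = head.forward_single(x)
assert cls_score.shape == (2, 9 * 15, 64, 64)  # num_anchors * num_classes
assert bbox_pred.shape == (2, 9 * 5, 64, 64)   # num_anchors * 5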

    @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
    def filter_bboxes(self, cls_scores, bbox_preds):
        """Filter predicted bounding boxes at each position of the feature
        maps. Only one bounding box with the highest score will be left at
        each position. This filter will be used in R3Det prior to the first
        feature refinement stage.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level.
                Has shape (N, num_anchors * num_classes, H, W).
            bbox_preds (list[Tensor]): Box energies / deltas for each scale
                level with shape (N, num_anchors * 5, H, W).

        Returns:
            list[list[Tensor]]: Best or refined rbboxes of each level \
                of each image.
        """
        num_levels = len(cls_scores)
        assert num_levels == len(bbox_preds)
        num_imgs = cls_scores[0].size(0)
        for i in range(num_levels):
            assert num_imgs == cls_scores[i].size(0) == bbox_preds[i].size(0)

        device = cls_scores[0].device
        featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)]
        mlvl_anchors = self.anchor_generator.grid_priors(
            featmap_sizes, device=device)

        bboxes_list = [[] for _ in range(num_imgs)]

        for lvl in range(num_levels):
            cls_score = cls_scores[lvl]
            bbox_pred = bbox_preds[lvl]

            anchors = mlvl_anchors[lvl]

            cls_score = cls_score.permute(0, 2, 3, 1)
            cls_score = cls_score.reshape(num_imgs, -1, self.num_anchors,
                                          self.cls_out_channels)

            cls_score, _ = cls_score.max(dim=-1, keepdim=True)
            best_ind = cls_score.argmax(dim=-2, keepdim=True)
            best_ind = best_ind.expand(-1, -1, -1, 5)

            bbox_pred = bbox_pred.permute(0, 2, 3, 1)
            bbox_pred = bbox_pred.reshape(num_imgs, -1, self.num_anchors, 5)
            best_pred = bbox_pred.gather(
                dim=-2, index=best_ind).squeeze(dim=-2)

            anchors = anchors.reshape(-1, self.num_anchors, 5)

            for img_id in range(num_imgs):
                best_ind_i = best_ind[img_id]
                best_pred_i = best_pred[img_id]
                best_anchor_i = anchors.gather(
                    dim=-2, index=best_ind_i).squeeze(dim=-2)
                best_bbox_i = self.bbox_coder.decode(best_anchor_i,
                                                     best_pred_i)
                bboxes_list[img_id].append(best_bbox_i.detach())

        return bboxes_list
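
In effect, filter_bboxes takes, per location, the max score over classes for
each anchor, picks the anchor with the highest such score, and decodes only
that anchor's 5 deltas. A sketch with random inputs matching the head built
above (all sizes are illustrative):

import torch

featmap_sizes = [(64, 64), (32, 32), (16, 16), (8, 8), (4, 4)]
cls_scores = [torch.rand(2, 9 * 15, h, w) for h, w in featmap_sizes]
bbox_preds = [torch.rand(2, 9 * 5, h, w) for h, w in featmap_sizes]

bboxes_list = head.filter_bboxes(cls_scores, bbox_preds)
# One inner list per image, one tensor per level; each position keeps the
# decoded (cx, cy, w, h, angle) of its single best-scoring anchor.
assert len(bboxes_list) == 2 and len(bboxes_list[0]) == 5
assert bboxes_list[0][0].shape == (64 * 64, 5)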

    @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
    def refine_bboxes(self, cls_scores, bbox_preds):
        """This function will be used in S2ANet, whose num_anchors=1.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level.
                Has shape (N, num_classes, H, W).
            bbox_preds (list[Tensor]): Box energies / deltas for each scale
                level with shape (N, 5, H, W).

        Returns:
            list[list[Tensor]]: Refined rbboxes of each level of each image.
        """
        num_levels = len(cls_scores)
        assert num_levels == len(bbox_preds)
        num_imgs = cls_scores[0].size(0)
        for i in range(num_levels):
            assert num_imgs == cls_scores[i].size(0) == bbox_preds[i].size(0)

        device = cls_scores[0].device
        featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)]
        mlvl_anchors = self.anchor_generator.grid_priors(
            featmap_sizes, device=device)

        bboxes_list = [[] for _ in range(num_imgs)]

        for lvl in range(num_levels):
            bbox_pred = bbox_preds[lvl]
            bbox_pred = bbox_pred.permute(0, 2, 3, 1)
            bbox_pred = bbox_pred.reshape(num_imgs, -1, 5)
            anchors = mlvl_anchors[lvl]

            for img_id in range(num_imgs):
                bbox_pred_i = bbox_pred[img_id]
                decode_bbox_i = self.bbox_coder.decode(anchors, bbox_pred_i)
                bboxes_list[img_id].append(decode_bbox_i.detach())

        return bboxes_list
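
refine_bboxes assumes one anchor per location, so the sketch below builds a
separate single-anchor head in the style of S2ANet's alignment stage
(scales=[4] and ratios=[1.0] give num_anchors == 1; all names and sizes here
are illustrative assumptions):

import torch
from mmrotate.models.dense_heads import RotatedRetinaHead

fam_head = RotatedRetinaHead(
    num_classes=15,
    in_channels=256,
    anchor_generator=dict(
        type='RotatedAnchorGenerator',
        scales=[4],
        ratios=[1.0],
        strides=[8, 16, 32, 64, 128]))

featmap_sizes = [(64, 64), (32, 32), (16, 16), (8, 8), (4, 4)]
cls_scores = [torch.rand(2, 15, h, w) for h, w in featmap_sizes]
bbox_preds = [torch.rand(2, 5, h, w) for h, w in featmap_sizes]

bboxes_list = fam_head.refine_bboxes(cls_scores, bbox_preds)
# No per-anchor selection here: every location's single prediction is
# decoded, giving one (H*W, 5) tensor per level per image.
assert bboxes_list[0][0].shape == (64 * 64, 5)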