# mmrotate.models.dense_heads.odm_refine_head source code
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmcv.runner import force_fp32
from ..builder import ROTATED_HEADS
from ..utils import ORConv2d, RotationInvariantPooling
from .rotated_retina_head import RotatedRetinaHead
@ROTATED_HEADS.register_module()
class ODMRefineHead(RotatedRetinaHead):
    """Rotated anchor-based refine head.

    It is part of the Oriented Detection Module (ODM). An oriented
    convolution (``ORConv2d``) produces orientation-sensitive features that
    feed the regression branch. A rotation-invariant pooling of those
    features feeds the classification branch.

    Instead of generated anchors, this head refines the rotated boxes
    (``rois``) predicted by a previous stage. ``loss`` stores them and
    ``get_anchors`` returns them as the anchors.

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        in_channels (int): Number of channels in the input feature map.
        stacked_convs (int, optional): Number of stacked convolutions in
            each of the classification and regression branches. Must be at
            least 1, because the first classification conv restores the
            channel count reduced by the rotation-invariant pooling.
            Default: 2.
        conv_cfg (dict, optional): Config dict for convolution layer.
            Default: None.
        norm_cfg (dict, optional): Config dict for normalization layer.
            Default: None.
        anchor_generator (dict): Config dict for anchor generator.
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """  # noqa: W605

    def __init__(self,
                 num_classes,
                 in_channels,
                 stacked_convs=2,
                 conv_cfg=None,
                 norm_cfg=None,
                 anchor_generator=dict(
                     type='PseudoAnchorGenerator',
                     strides=[8, 16, 32, 64, 128]),
                 init_cfg=dict(
                     type='Normal',
                     layer='Conv2d',
                     std=0.01,
                     override=dict(
                         type='Normal',
                         name='odm_cls',
                         std=0.01,
                         bias_prob=0.01)),
                 **kwargs):
        if stacked_convs < 1:
            raise ValueError(
                f'stacked_convs must be >= 1, got {stacked_convs}')
        # Rois from the previous stage; written by ``loss`` and read by
        # ``get_anchors``.
        self.bboxes_as_anchors = None
        self.stacked_convs = stacked_convs
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        # Forward the user's values explicitly. The previous literal
        # ``stacked_convs=2`` silently discarded the argument, and omitting
        # conv_cfg/norm_cfg let the parent fall back to its defaults.
        super(ODMRefineHead, self).__init__(
            num_classes,
            in_channels,
            stacked_convs=stacked_convs,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            anchor_generator=anchor_generator,
            init_cfg=init_cfg,
            **kwargs)

    def _init_layers(self):
        """Initialize layers of the head."""
        num_orientations = 8
        or_channels = self.feat_channels // num_orientations
        # The regression convs consume ``or_conv`` output with
        # ``feat_channels`` inputs, so this relies on ORConv2d expanding
        # ``or_channels`` filters over ``num_orientations`` rotations back
        # to ``feat_channels`` channels.
        self.or_conv = ORConv2d(
            self.feat_channels,
            or_channels,
            kernel_size=3,
            padding=1,
            arf_config=(1, num_orientations))
        # Previously hard-coded to 256 regardless of ``feat_channels``.
        self.or_pool = RotationInvariantPooling(self.feat_channels,
                                                num_orientations)
        self.cls_convs = nn.ModuleList()
        self.reg_convs = nn.ModuleList()
        for i in range(self.stacked_convs):
            # The cls branch starts from the pooled features, which have
            # ``or_channels`` channels. Later convs keep ``feat_channels``.
            cls_in_channels = or_channels if i == 0 else self.feat_channels
            self.reg_convs.append(
                ConvModule(
                    self.feat_channels,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg))
            self.cls_convs.append(
                ConvModule(
                    cls_in_channels,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg))
        self.odm_cls = nn.Conv2d(
            self.feat_channels,
            self.num_anchors * self.cls_out_channels,
            3,
            padding=1)
        # Five regression targets per anchor (rotated box).
        self.odm_reg = nn.Conv2d(
            self.feat_channels, self.num_anchors * 5, 3, padding=1)

    def forward_single(self, x):
        """Forward feature of a single scale level.

        Args:
            x (torch.Tensor): Features of a single scale level.

        Returns:
            tuple (torch.Tensor):
                - cls_score (torch.Tensor): Cls scores for a single scale
                  level; the channels number is
                  num_anchors * cls_out_channels.
                - bbox_pred (torch.Tensor): Box energies / deltas for a
                  single scale level; the channels number is
                  num_anchors * 5.
        """
        or_feat = self.or_conv(x)
        # Regression uses the orientation-sensitive features directly.
        reg_feat = or_feat
        # Classification uses the rotation-invariant pooled features.
        cls_feat = self.or_pool(or_feat)
        for cls_conv in self.cls_convs:
            cls_feat = cls_conv(cls_feat)
        for reg_conv in self.reg_convs:
            reg_feat = reg_conv(reg_feat)
        cls_score = self.odm_cls(cls_feat)
        bbox_pred = self.odm_reg(reg_feat)
        return cls_score, bbox_pred

    def get_anchors(self, featmap_sizes, img_metas, device='cuda'):
        """Get anchors according to feature map sizes.

        The anchors are the rois stored in ``self.bboxes_as_anchors`` by
        ``loss``: one list per image, holding one tensor per level. They
        are returned as detached clones, so no gradient flows back into the
        stage that produced them.

        Args:
            featmap_sizes (list[tuple]): Multi-level feature map sizes.
            img_metas (list[dict]): Image meta info.
            device (torch.device | str): Device for the returned valid
                flags.

        Returns:
            tuple (list[Tensor]):
                - anchor_list (list[list[Tensor]]): Anchors of each image
                  and level.
                - valid_flag_list (list[Tensor]): Valid flags of each
                  image.
        """
        assert self.bboxes_as_anchors is not None, (
            'ODMRefineHead.get_anchors needs rois from the previous stage; '
            'call loss() with rois first')
        anchor_list = [[
            bboxes_img_lvl.clone().detach() for bboxes_img_lvl in bboxes_img
        ] for bboxes_img in self.bboxes_as_anchors]
        # For each image, compute valid flags of the multi-level anchors.
        valid_flag_list = [
            self.anchor_generator.valid_flags(featmap_sizes,
                                              img_meta['pad_shape'], device)
            for img_meta in img_metas
        ]
        return anchor_list, valid_flag_list

    @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
    def loss(self,
             cls_scores,
             bbox_preds,
             gt_bboxes,
             gt_labels,
             img_metas,
             rois=None,
             gt_bboxes_ignore=None):
        """Loss function of ODMRefineHead.

        ``rois`` (rotated boxes of each level of each image, produced by the
        previous stage) are required. They are stored on the head so that
        ``get_anchors`` can use them as anchors during the parent's loss
        computation.
        """
        assert rois is not None
        self.bboxes_as_anchors = rois
        return super(ODMRefineHead, self).loss(
            cls_scores=cls_scores,
            bbox_preds=bbox_preds,
            gt_bboxes=gt_bboxes,
            gt_labels=gt_labels,
            img_metas=img_metas,
            gt_bboxes_ignore=gt_bboxes_ignore)

    @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
    def get_bboxes(self,
                   cls_scores,
                   bbox_preds,
                   img_metas,
                   cfg=None,
                   rescale=False,
                   rois=None):
        """Transform network output for a batch into labeled boxes.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level.
                Has shape (N, num_anchors * num_classes, H, W).
            bbox_preds (list[Tensor]): Box energies / deltas for each scale
                level with shape (N, num_anchors * 5, H, W).
            img_metas (list[dict]): Size / scale info for each image.
            cfg (mmcv.Config): Test / postprocessing configuration.
            rescale (bool): If True, return boxes in original image space.
            rois (list[list[Tensor]]): Input rbboxes of each level of each
                image. These are the rois output by former stages, which
                are to be refined. Required.

        Returns:
            list[tuple[Tensor, Tensor]]: Each item in result_list is a
                2-tuple. The first item is an (n, 6) tensor, where the first
                5 columns are bounding box positions (xc, yc, w, h, a) and
                the 6-th column is a score between 0 and 1. The second item
                is an (n,) tensor where each item is the class index of the
                corresponding box.
        """
        assert len(cls_scores) == len(bbox_preds)
        assert rois is not None
        result_list = []
        for img_id, img_meta in enumerate(img_metas):
            # Gather this image's predictions from every level.
            cls_score_list = [
                lvl_scores[img_id].detach() for lvl_scores in cls_scores
            ]
            bbox_pred_list = [
                lvl_preds[img_id].detach() for lvl_preds in bbox_preds
            ]
            proposals = self._get_bboxes_single(cls_score_list, bbox_pred_list,
                                                rois[img_id],
                                                img_meta['img_shape'],
                                                img_meta['scale_factor'], cfg,
                                                rescale)
            result_list.append(proposals)
        return result_list