# Source code for mmrotate.models.dense_heads.kfiou_rotate_retina_head
# Copyright (c) SJTU. All rights reserved.
from ..builder import ROTATED_HEADS
from .rotated_retina_head import RotatedRetinaHead
@ROTATED_HEADS.register_module()
class KFIoURRetinaHead(RotatedRetinaHead):
    """Rotated anchor-based RetinaNet head trained with the KFIoU loss.

    The only difference from :class:`RotatedRetinaHead` is how the
    regression loss is called: ``loss_bbox`` receives the raw regression
    deltas (``bbox_pred``, ``bbox_targets``) together with the boxes
    decoded from them (``pred_decode``, ``targets_decode``).

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        in_channels (int): Number of channels in the input feature map.
        stacked_convs (int, optional): Number of stacked convolutions.
            Default: 4.
        conv_cfg (dict, optional): Config dict for convolution layer.
            Default: None.
        norm_cfg (dict, optional): Config dict for normalization layer.
            Default: None.
        anchor_generator (dict, optional): Config dict for the anchor
            generator. Default: an ``AnchorGenerator`` with
            ``octave_base_scale=4``, ``scales_per_octave=3``,
            ``ratios=[0.5, 1.0, 2.0]`` and ``strides=[8, 16, 32, 64, 128]``.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: ``Normal(std=0.01)`` for ``Conv2d`` layers, with an
            override on ``retina_cls`` that also sets ``bias_prob=0.01``.
        **kwargs: Forwarded unchanged to :class:`RotatedRetinaHead`.
    """

    def __init__(self,
                 num_classes,
                 in_channels,
                 stacked_convs=4,
                 conv_cfg=None,
                 norm_cfg=None,
                 anchor_generator=dict(
                     type='AnchorGenerator',
                     octave_base_scale=4,
                     scales_per_octave=3,
                     ratios=[0.5, 1.0, 2.0],
                     strides=[8, 16, 32, 64, 128]),
                 init_cfg=dict(
                     type='Normal',
                     layer='Conv2d',
                     std=0.01,
                     override=dict(
                         type='Normal',
                         name='retina_cls',
                         std=0.01,
                         bias_prob=0.01)),
                 **kwargs):
        # NOTE(review): this attribute is assigned before the parent
        # constructor runs but is never read in this file; presumably the
        # parent head (or code elsewhere) consults it -- confirm before
        # removing or moving it after ``super().__init__``.
        self.bboxes_as_anchors = None
        # NOTE(review): the dict defaults above are shared between calls;
        # that is only safe while nothing downstream mutates them in place.
        super(KFIoURRetinaHead, self).__init__(
            num_classes=num_classes,
            in_channels=in_channels,
            stacked_convs=stacked_convs,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            anchor_generator=anchor_generator,
            init_cfg=init_cfg,
            **kwargs)

    def loss_single(self, cls_score, bbox_pred, anchors, labels, label_weights,
                    bbox_targets, bbox_weights, num_total_samples):
        """Compute loss of a single scale level.

        Args:
            cls_score (torch.Tensor): Box scores for each scale level with
                shape (N, num_anchors * num_classes, H, W).
            bbox_pred (torch.Tensor): Box energies / deltas for each scale
                level with shape (N, num_anchors * 5, H, W).
            anchors (torch.Tensor): Box reference for each scale level with
                shape (N, num_total_anchors, 5).
            labels (torch.Tensor): Labels of each anchor with shape
                (N, num_total_anchors).
            label_weights (torch.Tensor): Label weights of each anchor with
                shape (N, num_total_anchors).
            bbox_targets (torch.Tensor): BBox regression targets of each
                anchor with shape (N, num_total_anchors, 5).
            bbox_weights (torch.Tensor): BBox regression loss weights of each
                anchor with shape (N, num_total_anchors, 5).
            num_total_samples (int): If sampling, the total number of
                samples (equal to the number of total anchors); otherwise,
                the number of positive anchors. Used as ``avg_factor`` for
                both losses.

        Returns:
            tuple[torch.Tensor]: ``(loss_cls, loss_bbox)`` where

                - loss_cls (torch.Tensor): classification loss for this
                  scale level.
                - loss_bbox (torch.Tensor): regression loss for this scale
                  level.
        """
        # Classification loss: flatten to one row per anchor,
        # (N, A*C, H, W) -> (N, H, W, A*C) -> (N*H*W*A, C).
        labels = labels.reshape(-1)
        label_weights = label_weights.reshape(-1)
        cls_score = cls_score.permute(0, 2, 3,
                                      1).reshape(-1, self.cls_out_channels)
        loss_cls = self.loss_cls(
            cls_score, labels, label_weights, avg_factor=num_total_samples)

        # Regression loss: flatten predictions, targets, weights and anchors
        # to one 5-parameter row per anchor so they line up element-wise.
        bbox_targets = bbox_targets.reshape(-1, 5)
        bbox_weights = bbox_weights.reshape(-1, 5)
        bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(-1, 5)
        anchors = anchors.reshape(-1, 5)
        # The KFIoU loss also consumes boxes, so decode both predictions and
        # targets against the same anchors. NOTE(review): this assumes
        # ``bbox_targets`` holds encoded deltas (i.e. targets were not
        # already decoded upstream) -- confirm against the head's config.
        bbox_pred_decode = self.bbox_coder.decode(anchors, bbox_pred)
        bbox_targets_decode = self.bbox_coder.decode(anchors, bbox_targets)
        loss_bbox = self.loss_bbox(
            bbox_pred,
            bbox_targets,
            bbox_weights,
            pred_decode=bbox_pred_decode,
            targets_decode=bbox_targets_decode,
            avg_factor=num_total_samples)
        return loss_cls, loss_bbox