
Source code for mmrotate.models.dense_heads.kfiou_rotate_retina_head

# Copyright (c) SJTU. All rights reserved.
from ..builder import ROTATED_HEADS
from .rotated_retina_head import RotatedRetinaHead

@ROTATED_HEADS.register_module()
class KFIoURRetinaHead(RotatedRetinaHead):
    """Rotated anchor-based RetinaNet head trained with the KFIoU loss.

    It differs from :class:`RotatedRetinaHead` only in how the regression
    loss is invoked. ``loss_bbox`` receives four inputs:

    - the raw predicted deltas;
    - the regression targets;
    - the regression weights;
    - via the ``pred_decode`` / ``targets_decode`` keywords, the boxes
      obtained by decoding both the predictions and the targets against
      the anchors with ``self.bbox_coder``.

    Args:
        num_classes (int): Number of categories, excluding the background
            category.
        in_channels (int): Number of channels of the input feature map.
        stacked_convs (int, optional): Number of stacked convolutions.
            Default: 4.
        conv_cfg (dict, optional): Config dict for the convolution layers.
            Default: None.
        norm_cfg (dict, optional): Config dict for the normalization layers.
            Default: None.
        anchor_generator (dict): Config dict for the anchor generator.
        init_cfg (dict or list[dict], optional): Initialization config dict.
        **kwargs: Extra keyword arguments, forwarded unchanged to
            :class:`RotatedRetinaHead`.
    """

    # NOTE(review): the default ``anchor_generator`` / ``init_cfg`` dicts are
    # single objects shared by every call. This class passes them through
    # without modifying them; confirm the parent classes do not mutate them.
    def __init__(self,
                 num_classes,
                 in_channels,
                 stacked_convs=4,
                 conv_cfg=None,
                 norm_cfg=None,
                 anchor_generator=dict(
                     type='AnchorGenerator',
                     octave_base_scale=4,
                     scales_per_octave=3,
                     ratios=[0.5, 1.0, 2.0],
                     strides=[8, 16, 32, 64, 128]),
                 init_cfg=dict(
                     type='Normal',
                     layer='Conv2d',
                     std=0.01,
                     override=dict(
                         type='Normal',
                         name='retina_cls',
                         std=0.01,
                         bias_prob=0.01)),
                 **kwargs):
        # Set before the parent constructor runs (kept in that position in
        # case parent code reads it). Nothing in this class reads it.
        # Presumably it is consumed elsewhere; verify before removing.
        self.bboxes_as_anchors = None

        head_cfg = dict(
            num_classes=num_classes,
            in_channels=in_channels,
            stacked_convs=stacked_convs,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            anchor_generator=anchor_generator,
            init_cfg=init_cfg)
        super().__init__(**head_cfg, **kwargs)

    def loss_single(self, cls_score, bbox_pred, anchors, labels, label_weights,
                    bbox_targets, bbox_weights, num_total_samples):
        """Compute the classification and regression losses of one level.

        Args:
            cls_score (torch.Tensor): Box scores of this scale level, with
                shape (N, num_anchors * num_classes, H, W).
            bbox_pred (torch.Tensor): Box energies / deltas of this scale
                level, with shape (N, num_anchors * 5, H, W).
            anchors (torch.Tensor): Box reference of this scale level, with
                shape (N, num_total_anchors, 5).
            labels (torch.Tensor): Label of each anchor, with shape
                (N, num_total_anchors).
            label_weights (torch.Tensor): Label weight of each anchor, with
                shape (N, num_total_anchors).
            bbox_targets (torch.Tensor): Regression target of each anchor,
                with shape (N, num_total_anchors, 5).
            bbox_weights (torch.Tensor): Regression loss weight of each
                anchor, with shape (N, num_total_anchors, 5).
            num_total_samples (int): With sampling, the total number of
                anchors; otherwise the number of positive anchors. Used as
                ``avg_factor`` for both losses.

        Returns:
            tuple[torch.Tensor]: ``(loss_cls, loss_bbox)`` for this level.
        """
        # Classification: one row of class logits per anchor.
        flat_scores = cls_score.permute(0, 2, 3, 1).reshape(
            -1, self.cls_out_channels)
        flat_labels = labels.reshape(-1)
        flat_label_weights = label_weights.reshape(-1)
        loss_cls = self.loss_cls(
            flat_scores,
            flat_labels,
            flat_label_weights,
            avg_factor=num_total_samples)

        # Regression: one row of five box parameters per anchor.
        box_dim = 5
        flat_anchors = anchors.reshape(-1, box_dim)
        flat_deltas = bbox_pred.permute(0, 2, 3, 1).reshape(-1, box_dim)
        flat_targets = bbox_targets.reshape(-1, box_dim)
        flat_weights = bbox_weights.reshape(-1, box_dim)

        # The KFIoU ``loss_bbox`` also takes boxes decoded against the
        # anchors. Predictions are decoded first, then targets.
        decode = self.bbox_coder.decode
        loss_bbox = self.loss_bbox(
            flat_deltas,
            flat_targets,
            flat_weights,
            pred_decode=decode(flat_anchors, flat_deltas),
            targets_decode=decode(flat_anchors, flat_targets),
            avg_factor=num_total_samples)
        return loss_cls, loss_bbox
Read the Docs v: v0.3.3
On Read the Docs
Project Home

Free document hosting provided by Read the Docs.