Shortcuts

Source code for mmrotate.models.dense_heads.kfiou_rotate_retina_refine_head

# Copyright (c) SJTU. All rights reserved.
import torch
from mmcv.runner import force_fp32

from ..builder import ROTATED_HEADS
from .kfiou_rotate_retina_head import KFIoURRetinaHead


@ROTATED_HEADS.register_module()
class KFIoURRetinaRefineHead(KFIoURRetinaHead):
    """Rotational anchor-based refine head using the KFIoU loss.

    The difference from ``RRetinaRefineHead`` is that its ``loss_bbox``
    requires ``bbox_pred``, ``bbox_targets``, ``pred_decode`` and
    ``targets_decode`` as inputs.

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        in_channels (int): Number of channels in the input feature map.
        stacked_convs (int, optional): Number of stacked convolutions.
        conv_cfg (dict, optional): Config dict for convolution layer.
            Default: None.
        norm_cfg (dict, optional): Config dict for normalization layer.
            Default: None.
        anchor_generator (dict): Config dict for anchor generator.
        bbox_coder (dict): Config of bounding box coder.
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """  # noqa: W605

    def __init__(self,
                 num_classes,
                 in_channels,
                 stacked_convs=4,
                 conv_cfg=None,
                 norm_cfg=None,
                 anchor_generator=dict(
                     type='PseudoAnchorGenerator',
                     strides=[8, 16, 32, 64, 128]),
                 bbox_coder=dict(
                     type='DeltaXYWHABBoxCoder',
                     target_means=(.0, .0, .0, .0, .0),
                     target_stds=(1.0, 1.0, 1.0, 1.0, 1.0)),
                 init_cfg=dict(
                     type='Normal',
                     layer='Conv2d',
                     std=0.01,
                     override=dict(
                         type='Normal',
                         name='retina_cls',
                         std=0.01,
                         bias_prob=0.01)),
                 **kwargs):
        # Refined bboxes from a previous stage serve as anchors; they are
        # injected via ``loss`` before ``get_anchors`` consumes them.
        self.bboxes_as_anchors = None
        super().__init__(
            num_classes=num_classes,
            in_channels=in_channels,
            stacked_convs=stacked_convs,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            anchor_generator=anchor_generator,
            bbox_coder=bbox_coder,
            init_cfg=init_cfg,
            **kwargs)
[docs] @force_fp32(apply_to=('cls_scores', 'bbox_preds')) def refine_bboxes(self, cls_scores, bbox_preds, rois): """Refine predicted bounding boxes at each position of the feature maps. This method will be used in R3Det in refinement stages. Args: cls_scores (list[Tensor]): Box scores for each scale level Has shape (N, num_classes, H, W) bbox_preds (list[Tensor]): Box energies / deltas for each scale level with shape (N, 5, H, W) rois (list[list[Tensor]]): input rbboxes of each level of each image. rois output by former stages and are to be refined Returns: list[list[Tensor]]: best or refined rbboxes of each level of each \ image. """ num_levels = len(cls_scores) assert num_levels == len(bbox_preds) num_imgs = cls_scores[0].size(0) for i in range(num_levels): assert num_imgs == cls_scores[i].size(0) == bbox_preds[i].size(0) bboxes_list = [[] for _ in range(num_imgs)] assert rois is not None mlvl_rois = [torch.cat(r) for r in zip(*rois)] for lvl in range(num_levels): bbox_pred = bbox_preds[lvl] rois = mlvl_rois[lvl] assert bbox_pred.size(1) == 5 bbox_pred = bbox_pred.permute(0, 2, 3, 1) bbox_pred = bbox_pred.reshape(-1, 5) refined_bbox = self.bbox_coder.decode(rois, bbox_pred) refined_bbox = refined_bbox.reshape(num_imgs, -1, 5) for img_id in range(num_imgs): bboxes_list[img_id].append(refined_bbox[img_id].detach()) return bboxes_list
[docs] def get_anchors(self, featmap_sizes, img_metas, device='cuda'): """Get anchors according to feature map sizes. Args: featmap_sizes (list[tuple]): Multi-level feature map sizes. img_metas (list[dict]): Image meta info. bboxes_as_anchors (list[list[Tensor]]) bboxes of levels of images. before further regression just like anchors. device (torch.device | str): Device for returned tensors Returns: tuple (list[Tensor]): - anchor_list (list[Tensor]): Anchors of each image - valid_flag_list (list[Tensor]): Valid flags of each image """ anchor_list = [[ bboxes_img_lvl.clone().detach() for bboxes_img_lvl in bboxes_img ] for bboxes_img in self.bboxes_as_anchors] # for each image, we compute valid flags of multi level anchors valid_flag_list = [] for img_id, img_meta in enumerate(img_metas): multi_level_flags = self.anchor_generator.valid_flags( featmap_sizes, img_meta['pad_shape'], device) valid_flag_list.append(multi_level_flags) return anchor_list, valid_flag_list
[docs] @force_fp32(apply_to=('cls_scores', 'bbox_preds')) def loss(self, cls_scores, bbox_preds, gt_bboxes, gt_labels, img_metas, rois=None, gt_bboxes_ignore=None): """Loss function of KFIoURRetinaRefineHead.""" assert rois is not None self.bboxes_as_anchors = rois return super(KFIoURRetinaRefineHead, self).loss( cls_scores=cls_scores, bbox_preds=bbox_preds, gt_bboxes=gt_bboxes, gt_labels=gt_labels, img_metas=img_metas, gt_bboxes_ignore=gt_bboxes_ignore)
[docs] @force_fp32(apply_to=('cls_scores', 'bbox_preds')) def get_bboxes(self, cls_scores, bbox_preds, img_metas, cfg=None, rescale=False, rois=None): """Transform network output for a batch into labeled boxes. Args: cls_scores (list[Tensor]): Box scores for each scale level Has shape (N, num_anchors * num_classes, H, W) bbox_preds (list[Tensor]): Box energies / deltas for each scale level with shape (N, num_anchors * 5, H, W) img_metas (list[dict]): size / scale info for each image cfg (mmcv.Config): test / postprocessing configuration rois (list[list[Tensor]]): input rbboxes of each level of each image. rois output by former stages and are to be refined rescale (bool): if True, return boxes in original image space Returns: list[tuple[Tensor, Tensor]]: each item in result_list is 2-tuple. The first item is an (n, 6) tensor, where the first 5 columns are bounding box positions (xc, yc, w, h, a) and the 6-th column is a score between 0 and 1. The second item is a (n,) tensor where each item is the class index of the corresponding box. """ num_levels = len(cls_scores) assert len(cls_scores) == len(bbox_preds) assert rois is not None result_list = [] for img_id, _ in enumerate(img_metas): cls_score_list = [ cls_scores[i][img_id].detach() for i in range(num_levels) ] bbox_pred_list = [ bbox_preds[i][img_id].detach() for i in range(num_levels) ] img_shape = img_metas[img_id]['img_shape'] scale_factor = img_metas[img_id]['scale_factor'] proposals = self._get_bboxes_single(cls_score_list, bbox_pred_list, rois[img_id], img_shape, scale_factor, cfg, rescale) result_list.append(proposals) return result_list
Read the Docs v: v0.3.2
Versions
latest
stable
v0.3.2
v0.3.1
v0.3.0
v0.2.0
v0.1.1
v0.1.0
main
dev
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.