Shortcuts

Source code for mmrotate.models.dense_heads.rotated_atss_head

# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmdet.core import images_to_levels, multi_apply, unmap

from mmrotate.core import obb2hbb, rotated_anchor_inside_flags
from ..builder import ROTATED_HEADS
from .rotated_retina_head import RotatedRetinaHead
from .utils import get_num_level_anchors_inside


@ROTATED_HEADS.register_module()
class RotatedATSSHead(RotatedRetinaHead):
    r"""An anchor-based head used in `ATSS
    <https://arxiv.org/abs/1912.02424>`_.

    The head contains two subnetworks. The first classifies anchor boxes and
    the second regresses deltas for the anchors.
    """  # noqa: W605

    def _get_targets_single(self,
                            flat_anchors,
                            valid_flags,
                            num_level_anchors,
                            gt_bboxes,
                            gt_bboxes_ignore,
                            gt_labels,
                            img_meta,
                            label_channels=1,
                            unmap_outputs=True):
        """Compute regression and classification targets for anchors in a
        single image.

        Args:
            flat_anchors (torch.Tensor): Multi-level anchors of the image,
                which are concatenated into a single tensor of shape
                (num_anchors, 5).
            valid_flags (torch.Tensor): Multi-level valid flags of the image,
                which are concatenated into a single tensor of shape
                (num_anchors,).
            num_level_anchors (list[int]): Number of anchors of each scale
                level (``get_targets`` passes a list of per-level counts).
            gt_bboxes (torch.Tensor): Ground truth bboxes of the image,
                shape (num_gts, 5).
            gt_bboxes_ignore (torch.Tensor | None): Ground truth bboxes to be
                ignored, shape (num_ignored_gts, 5).
            gt_labels (torch.Tensor | None): Ground truth labels of each box,
                shape (num_gts,). When ``None``, every positive anchor is
                labelled ``0``.
            img_meta (dict): Meta info of the image. Only ``'img_shape'`` is
                read here.
            label_channels (int): Channel of label. Not used by this method;
                kept for interface compatibility.
            unmap_outputs (bool): Whether to map outputs back to the original
                set of anchors.

        Returns:
            tuple: A 7-tuple. If no anchor lies inside the image, every
            element is ``None``. Otherwise:

            - labels (Tensor): Label of every anchor; anchors that are not
              positive keep the background label ``self.num_classes``.
            - label_weights (Tensor): Label weight of every anchor.
            - bbox_targets (Tensor): BBox regression target of every anchor.
            - bbox_weights (Tensor): BBox weight of every anchor.
            - pos_inds (Tensor): Indices of positive anchors, relative to the
              anchors kept by the inside-flags filter.
            - neg_inds (Tensor): Indices of negative anchors, relative to the
              anchors kept by the inside-flags filter.
            - sampling_result: object `SamplingResult`, sampling result.
        """
        inside_flags = rotated_anchor_inside_flags(
            flat_anchors, valid_flags, img_meta['img_shape'][:2],
            self.train_cfg.allowed_border)
        if not inside_flags.any():
            return (None, ) * 7

        # assign gt and sample anchors
        anchors = flat_anchors[inside_flags, :]
        num_level_anchors_inside = get_num_level_anchors_inside(
            num_level_anchors, inside_flags)

        # Optionally match against the converted (obb2hbb) form of the GT
        # boxes. Only the assignment sees the converted boxes; sampling and
        # target encoding below always use the original rotated ``gt_bboxes``.
        if self.assign_by_circumhbbox is not None:
            gt_bboxes_assign = obb2hbb(gt_bboxes, self.assign_by_circumhbbox)
        else:
            gt_bboxes_assign = gt_bboxes
        assign_result = self.assigner.assign(
            anchors, num_level_anchors_inside, gt_bboxes_assign,
            gt_bboxes_ignore, None if self.sampling else gt_labels)

        sampling_result = self.sampler.sample(assign_result, anchors,
                                              gt_bboxes)

        num_valid_anchors = anchors.shape[0]
        bbox_targets = torch.zeros_like(anchors)
        bbox_weights = torch.zeros_like(anchors)
        # every anchor starts as background (label == num_classes)
        labels = anchors.new_full((num_valid_anchors, ),
                                  self.num_classes,
                                  dtype=torch.long)
        label_weights = anchors.new_zeros(
            num_valid_anchors, dtype=torch.float)

        pos_inds = sampling_result.pos_inds
        neg_inds = sampling_result.neg_inds
        if len(pos_inds) > 0:
            if not self.reg_decoded_bbox:
                pos_bbox_targets = self.bbox_coder.encode(
                    sampling_result.pos_bboxes, sampling_result.pos_gt_bboxes)
            else:
                # regress the GT boxes directly instead of encoded deltas
                pos_bbox_targets = sampling_result.pos_gt_bboxes
            bbox_targets[pos_inds, :] = pos_bbox_targets
            bbox_weights[pos_inds, :] = 1.0
            if gt_labels is None:
                # Only rpn gives gt_labels as None
                # Foreground is the first class since v2.5.0
                labels[pos_inds] = 0
            else:
                labels[pos_inds] = gt_labels[
                    sampling_result.pos_assigned_gt_inds]
            # a non-positive ``pos_weight`` means "use the default weight 1"
            if self.train_cfg.pos_weight <= 0:
                label_weights[pos_inds] = 1.0
            else:
                label_weights[pos_inds] = self.train_cfg.pos_weight
        if len(neg_inds) > 0:
            label_weights[neg_inds] = 1.0

        # map up to original set of anchors
        if unmap_outputs:
            num_total_anchors = flat_anchors.size(0)
            labels = unmap(
                labels, num_total_anchors, inside_flags,
                fill=self.num_classes)  # fill bg label
            label_weights = unmap(label_weights, num_total_anchors,
                                  inside_flags)
            bbox_targets = unmap(bbox_targets, num_total_anchors,
                                 inside_flags)
            bbox_weights = unmap(bbox_weights, num_total_anchors,
                                 inside_flags)

        return (labels, label_weights, bbox_targets, bbox_weights, pos_inds,
                neg_inds, sampling_result)

    def get_targets(self,
                    anchor_list,
                    valid_flag_list,
                    gt_bboxes_list,
                    img_metas,
                    gt_bboxes_ignore_list=None,
                    gt_labels_list=None,
                    label_channels=1,
                    unmap_outputs=True,
                    return_sampling_results=False):
        """Compute regression and classification targets for anchors in
        multiple images.

        Args:
            anchor_list (list[list[Tensor]]): Multi level anchors of each
                image. The outer list indicates images, and the inner list
                corresponds to feature levels of the image. Each element of
                the inner list is a tensor of shape (num_anchors, 5).
            valid_flag_list (list[list[Tensor]]): Multi level valid flags of
                each image. The outer list indicates images, and the inner
                list corresponds to feature levels of the image. Each element
                of the inner list is a tensor of shape (num_anchors, ).
            gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image.
            img_metas (list[dict]): Meta info of each image.
            gt_bboxes_ignore_list (list[Tensor] | None): Ground truth bboxes
                to be ignored. ``None`` is expanded to one ``None`` per image.
            gt_labels_list (list[Tensor] | None): Ground truth labels of each
                box. ``None`` is expanded to one ``None`` per image.
            label_channels (int): Channel of label.
            unmap_outputs (bool): Whether to map outputs back to the original
                set of anchors.
            return_sampling_results (bool): Whether to append the per-image
                list of sampling results to the returned tuple.

        Returns:
            tuple | None: ``None`` if any image has no valid anchors.
            Otherwise a tuple containing learning targets.

            - labels_list (list[Tensor]): Labels of each level.
            - label_weights_list (list[Tensor]): Label weights of each level.
            - bbox_targets_list (list[Tensor]): BBox targets of each level.
            - bbox_weights_list (list[Tensor]): BBox weights of each level.
            - num_total_pos (int): Number of positive samples in all images
              (each image contributes at least 1).
            - num_total_neg (int): Number of negative samples in all images
              (each image contributes at least 1).
            - sampling_results_list (list): Only present when
              ``return_sampling_results`` is True.

            additional_returns: This function enables user-defined returns
            from ``self._get_targets_single``. These returns are currently
            refined to properties at each feature map (HxW dimension).
            The results will be concatenated after the end.
        """
        num_imgs = len(img_metas)
        assert len(anchor_list) == len(valid_flag_list) == num_imgs

        # anchor number of multi levels; taken from the first image and
        # reused for every image
        num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]
        num_level_anchors_list = [num_level_anchors] * num_imgs

        # concat all level anchors to a single tensor
        concat_anchor_list = []
        concat_valid_flag_list = []
        for img_anchors, img_valid_flags in zip(anchor_list,
                                                valid_flag_list):
            assert len(img_anchors) == len(img_valid_flags)
            concat_anchor_list.append(torch.cat(img_anchors))
            concat_valid_flag_list.append(torch.cat(img_valid_flags))

        # compute targets for each image
        if gt_bboxes_ignore_list is None:
            gt_bboxes_ignore_list = [None for _ in range(num_imgs)]
        if gt_labels_list is None:
            gt_labels_list = [None for _ in range(num_imgs)]
        results = multi_apply(
            self._get_targets_single,
            concat_anchor_list,
            concat_valid_flag_list,
            num_level_anchors_list,
            gt_bboxes_list,
            gt_bboxes_ignore_list,
            gt_labels_list,
            img_metas,
            label_channels=label_channels,
            unmap_outputs=unmap_outputs)
        (all_labels, all_label_weights, all_bbox_targets, all_bbox_weights,
         pos_inds_list, neg_inds_list, sampling_results_list) = results[:7]
        rest_results = results[7:]  # user-added return values

        # no valid anchors
        if any(labels is None for labels in all_labels):
            return None

        # sampled anchors of all images
        num_total_pos = sum(max(inds.numel(), 1) for inds in pos_inds_list)
        num_total_neg = sum(max(inds.numel(), 1) for inds in neg_inds_list)

        # split targets to a list w.r.t. multiple levels
        labels_list = images_to_levels(all_labels, num_level_anchors)
        label_weights_list = images_to_levels(all_label_weights,
                                              num_level_anchors)
        bbox_targets_list = images_to_levels(all_bbox_targets,
                                             num_level_anchors)
        bbox_weights_list = images_to_levels(all_bbox_weights,
                                             num_level_anchors)
        res = (labels_list, label_weights_list, bbox_targets_list,
               bbox_weights_list, num_total_pos, num_total_neg)
        if return_sampling_results:
            res = res + (sampling_results_list, )

        # user-added return values are regrouped per level as well
        rest_by_level = tuple(
            images_to_levels(r, num_level_anchors) for r in rest_results)
        return res + rest_by_level
Read the Docs v: v0.3.4
Versions
latest
stable
1.x
v1.0.0rc0
v0.3.4
v0.3.3
v0.3.2
v0.3.1
v0.3.0
v0.2.0
v0.1.1
v0.1.0
main
dev
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.