Source code for mmrotate.models.dense_heads.csl_rotated_fcos_head

# Copyright (c) OpenMMLab. All rights reserved.

import torch
import torch.nn as nn
from mmcv.cnn import Scale
from mmcv.runner import force_fp32
from mmdet.core import reduce_mean

from mmrotate.core import build_bbox_coder, multiclass_nms_rotated
from ..builder import ROTATED_HEADS
from .rotated_anchor_free_head import RotatedAnchorFreeHead
from .rotated_fcos_head import RotatedFCOSHead

INF = 1e8


@ROTATED_HEADS.register_module()
class CSLRFCOSHead(RotatedFCOSHead):
    """Use `Circular Smooth Label (CSL)
    <https://link.springer.com/chapter/10.1007/978-3-030-58598-3_40>`_
    in `FCOS <https://arxiv.org/abs/1904.01355>`_.

    Args:
        separate_angle (bool): If True, angle prediction is separated from
            the bbox regression loss. Only True is supported in CSL.
            Default: True.
        scale_angle (bool): If True, add a learnable scale to the angle
            prediction branch. Only False is supported in CSL.
            Default: False.
        angle_coder (dict): Config of angle coder.
    """  # noqa: E501

    def __init__(self,
                 separate_angle=True,
                 scale_angle=False,
                 angle_coder=dict(
                     type='CSLCoder',
                     angle_version='le90',
                     omega=1,
                     window='gaussian',
                     radius=6),
                 **kwargs):
        self.angle_coder = build_bbox_coder(angle_coder)
        assert separate_angle, 'Only support separate angle in CSL'
        assert scale_angle is False, 'Only support no scale angle in CSL'
        self.coding_len = self.angle_coder.coding_len
        super().__init__(
            separate_angle=separate_angle, scale_angle=scale_angle, **kwargs)

    def _init_layers(self):
        """Initialize layers of the head."""
        RotatedAnchorFreeHead._init_layers(self)
        self.conv_centerness = nn.Conv2d(self.feat_channels, 1, 3, padding=1)
        self.conv_angle = nn.Conv2d(
            self.feat_channels, self.coding_len, 3, padding=1)
        self.scales = nn.ModuleList([Scale(1.0) for _ in self.strides])
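    # Editorial note, not in the original source: with the default CSLCoder
    # above ('le90', omega=1), the 180 degree angle range is split into
    # coding_len = 180 / omega = 180 bins, so ``conv_angle`` predicts one
    # logit per bin instead of a single regressed angle. The 'gaussian'
    # window with radius=6 smooths the one-hot bin label so that bins near
    # the true angle also receive positive supervision; this is what makes
    # the label "circularly smooth".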
    @force_fp32(
        apply_to=('cls_scores', 'bbox_preds', 'angle_preds', 'centernesses'))
    def loss(self,
             cls_scores,
             bbox_preds,
             angle_preds,
             centernesses,
             gt_bboxes,
             gt_labels,
             img_metas,
             gt_bboxes_ignore=None):
        """Compute loss of the head.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level,
                each is a 4D-tensor, the channel number is
                num_points * num_classes.
            bbox_preds (list[Tensor]): Box energies / deltas for each scale
                level, each is a 4D-tensor, the channel number is
                num_points * 4.
            angle_preds (list[Tensor]): Box angle for each scale level,
                each is a 4D-tensor, the channel number is
                num_points * coding_len.
            centernesses (list[Tensor]): Centerness for each scale level,
                each is a 4D-tensor, the channel number is num_points * 1.
            gt_bboxes (list[Tensor]): Ground truth bboxes for each image,
                each with shape (num_gts, 5) in (cx, cy, w, h, angle) format.
            gt_labels (list[Tensor]): Class indices corresponding to each
                box.
            img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            gt_bboxes_ignore (None | list[Tensor]): Specify which bounding
                boxes can be ignored when computing the loss.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        assert len(cls_scores) == len(bbox_preds) \
               == len(angle_preds) == len(centernesses)
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        all_level_points = self.prior_generator.grid_priors(
            featmap_sizes,
            dtype=bbox_preds[0].dtype,
            device=bbox_preds[0].device)
        labels, bbox_targets, angle_targets = self.get_targets(
            all_level_points, gt_bboxes, gt_labels)

        num_imgs = cls_scores[0].size(0)
        # flatten cls_scores, bbox_preds and centerness
        flatten_cls_scores = [
            cls_score.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels)
            for cls_score in cls_scores
        ]
        flatten_bbox_preds = [
            bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
            for bbox_pred in bbox_preds
        ]
        flatten_angle_preds = [
            angle_pred.permute(0, 2, 3, 1).reshape(-1, self.coding_len)
            for angle_pred in angle_preds
        ]
        flatten_centerness = [
            centerness.permute(0, 2, 3, 1).reshape(-1)
            for centerness in centernesses
        ]
        flatten_cls_scores = torch.cat(flatten_cls_scores)
        flatten_bbox_preds = torch.cat(flatten_bbox_preds)
        flatten_angle_preds = torch.cat(flatten_angle_preds)
        flatten_centerness = torch.cat(flatten_centerness)
        flatten_labels = torch.cat(labels)
        flatten_bbox_targets = torch.cat(bbox_targets)
        flatten_angle_targets = torch.cat(angle_targets)
        # repeat points to align with bbox_preds
        flatten_points = torch.cat(
            [points.repeat(num_imgs, 1) for points in all_level_points])

        # FG cat_id: [0, num_classes - 1], BG cat_id: num_classes
        bg_class_ind = self.num_classes
        pos_inds = ((flatten_labels >= 0)
                    & (flatten_labels < bg_class_ind)).nonzero().reshape(-1)
        num_pos = torch.tensor(
            len(pos_inds), dtype=torch.float, device=bbox_preds[0].device)
        num_pos = max(reduce_mean(num_pos), 1.0)
        loss_cls = self.loss_cls(
            flatten_cls_scores, flatten_labels, avg_factor=num_pos)

        pos_bbox_preds = flatten_bbox_preds[pos_inds]
        pos_angle_preds = flatten_angle_preds[pos_inds]
        pos_centerness = flatten_centerness[pos_inds]
        pos_bbox_targets = flatten_bbox_targets[pos_inds]
        pos_angle_targets = flatten_angle_targets[pos_inds]
        pos_centerness_targets = self.centerness_target(pos_bbox_targets)
        # centerness weighted iou loss
        centerness_denorm = max(
            reduce_mean(pos_centerness_targets.sum().detach()), 1e-6)

        if len(pos_inds) > 0:
            pos_points = flatten_points[pos_inds]
            # NOTE: 'seprate_angle' (sic) is the attribute name set by the
            # parent class RotatedFCOSHead.
            if self.seprate_angle:
                bbox_coder = self.h_bbox_coder
            else:
                bbox_coder = self.bbox_coder
                pos_bbox_preds = torch.cat(
                    [pos_bbox_preds, pos_angle_preds], dim=-1)
                pos_bbox_targets = torch.cat(
                    [pos_bbox_targets, pos_angle_targets], dim=-1)
            pos_decoded_bbox_preds = bbox_coder.decode(pos_points,
                                                       pos_bbox_preds)
            pos_decoded_target_preds = bbox_coder.decode(
                pos_points, pos_bbox_targets)
            loss_bbox = self.loss_bbox(
                pos_decoded_bbox_preds,
                pos_decoded_target_preds,
                weight=pos_centerness_targets,
                avg_factor=centerness_denorm)
            if self.seprate_angle:
                loss_angle = self.loss_angle(
                    pos_angle_preds, pos_angle_targets, avg_factor=num_pos)
            loss_centerness = self.loss_centerness(
                pos_centerness, pos_centerness_targets, avg_factor=num_pos)
        else:
            loss_bbox = pos_bbox_preds.sum()
            loss_centerness = pos_centerness.sum()
            if self.seprate_angle:
                loss_angle = pos_angle_preds.sum()

        if self.seprate_angle:
            return dict(
                loss_cls=loss_cls,
                loss_bbox=loss_bbox,
                loss_angle=loss_angle,
                loss_centerness=loss_centerness)
        else:
            return dict(
                loss_cls=loss_cls,
                loss_bbox=loss_bbox,
                loss_centerness=loss_centerness)
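    # Editorial shape note, not in the original source: for a single
    # 1024x1024 image with the default strides (4, 8, 16, 32, 64), the
    # per-level point counts are 256**2, 128**2, 64**2, 32**2 and 16**2,
    # i.e. 65536 + 16384 + 4096 + 1024 + 256 = 87296 rows after flattening.
    # The flattened tensors above then have shapes
    #
    #   flatten_cls_scores : (87296, cls_out_channels)
    #   flatten_bbox_preds : (87296, 4)
    #   flatten_angle_preds: (87296, coding_len)
    #   flatten_centerness : (87296,)
    #
    # Because separate_angle is always True for CSL, the bbox loss is
    # computed on horizontal (l, t, r, b) boxes decoded by h_bbox_coder,
    # while loss_angle is computed directly on the smoothed bin labels.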
    def _get_target_single(self, gt_bboxes, gt_labels, points, regress_ranges,
                           num_points_per_lvl):
        """Compute regression, classification and angle targets for a single
        image."""
        num_points = points.size(0)
        num_gts = gt_labels.size(0)
        if num_gts == 0:
            return gt_labels.new_full((num_points,), self.num_classes), \
                   gt_bboxes.new_zeros((num_points, 4)), \
                   gt_bboxes.new_zeros((num_points, self.coding_len))

        labels, bbox_targets, angle_targets = \
            super(CSLRFCOSHead, self)._get_target_single(
                gt_bboxes, gt_labels, points, regress_ranges,
                num_points_per_lvl)
        angle_targets = self.angle_coder.encode(angle_targets)
        return labels, bbox_targets, angle_targets

    def _get_bboxes_single(self,
                           cls_scores,
                           bbox_preds,
                           angle_preds,
                           centernesses,
                           mlvl_points,
                           img_shape,
                           scale_factor,
                           cfg,
                           rescale=False):
        """Transform outputs for a single batch item into bbox predictions.

        Args:
            cls_scores (list[Tensor]): Box scores for a single scale level
                with shape (num_points * num_classes, H, W).
            bbox_preds (list[Tensor]): Box energies / deltas for a single
                scale level with shape (num_points * 4, H, W).
            angle_preds (list[Tensor]): Box angle for a single scale level
                with shape (num_points * coding_len, H, W).
            centernesses (list[Tensor]): Centerness for a single scale level
                with shape (num_points * 1, H, W).
            mlvl_points (list[Tensor]): Box reference for a single scale
                level with shape (num_total_points, 2).
            img_shape (tuple[int]): Shape of the input image,
                (height, width, 3).
            scale_factor (ndarray): Scale factor of the image arranged as
                (w_scale, h_scale, w_scale, h_scale).
            cfg (mmcv.Config): Test / postprocessing configuration,
                if None, test_cfg would be used.
            rescale (bool): If True, return boxes in original image space.

        Returns:
            Tensor: Labeled boxes in shape (n, 6), where the first 5 columns
                are bounding box positions (x, y, w, h, angle) and the 6-th
                column is a score between 0 and 1.
        """
        cfg = self.test_cfg if cfg is None else cfg
        assert len(cls_scores) == len(bbox_preds) == len(mlvl_points)
        mlvl_bboxes = []
        mlvl_scores = []
        mlvl_centerness = []
        for cls_score, bbox_pred, angle_pred, centerness, points in zip(
                cls_scores, bbox_preds, angle_preds, centernesses,
                mlvl_points):
            assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
            scores = cls_score.permute(1, 2, 0).reshape(
                -1, self.cls_out_channels).sigmoid()
            centerness = centerness.permute(1, 2, 0).reshape(-1).sigmoid()
            bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
            angle_pred = angle_pred.permute(1, 2, 0).reshape(
                -1, self.coding_len)
            nms_pre = cfg.get('nms_pre', -1)
            if nms_pre > 0 and scores.shape[0] > nms_pre:
                max_scores, _ = (scores * centerness[:, None]).max(dim=1)
                _, topk_inds = max_scores.topk(nms_pre)
                points = points[topk_inds, :]
                bbox_pred = bbox_pred[topk_inds, :]
                angle_pred = angle_pred[topk_inds, :]
                scores = scores[topk_inds, :]
                centerness = centerness[topk_inds]
            angle_pred = self.angle_coder.decode(angle_pred).unsqueeze(-1)
            bbox_pred = torch.cat([bbox_pred, angle_pred], dim=-1)
            bboxes = self.bbox_coder.decode(
                points, bbox_pred, max_shape=img_shape)
            mlvl_bboxes.append(bboxes)
            mlvl_scores.append(scores)
            mlvl_centerness.append(centerness)
        mlvl_bboxes = torch.cat(mlvl_bboxes)
        if rescale:
            scale_factor = mlvl_bboxes.new_tensor(scale_factor)
            mlvl_bboxes[..., :4] = mlvl_bboxes[..., :4] / scale_factor
        mlvl_scores = torch.cat(mlvl_scores)
        padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
        # remind that we set FG labels to [0, num_class-1] since mmdet v2.0
        # BG cat_id: num_class
        mlvl_scores = torch.cat([mlvl_scores, padding], dim=1)
        mlvl_centerness = torch.cat(mlvl_centerness)
        det_bboxes, det_labels = multiclass_nms_rotated(
            mlvl_bboxes,
            mlvl_scores,
            cfg.score_thr,
            cfg.nms,
            cfg.max_per_img,
            score_factors=mlvl_centerness)
        return det_bboxes, det_labels
    @force_fp32(
        apply_to=('cls_scores', 'bbox_preds', 'angle_preds', 'centernesses'))
    def refine_bboxes(self, cls_scores, bbox_preds, angle_preds,
                      centernesses):
        """This function will be used in S2ANet, whose num_anchors=1."""
        num_levels = len(cls_scores)
        assert num_levels == len(bbox_preds)
        num_imgs = cls_scores[0].size(0)
        for i in range(num_levels):
            assert num_imgs == cls_scores[i].size(0) == bbox_preds[i].size(0)

        featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)]
        mlvl_points = self.prior_generator.grid_priors(
            featmap_sizes, bbox_preds[0].dtype, bbox_preds[0].device)
        bboxes_list = [[] for _ in range(num_imgs)]

        for lvl in range(num_levels):
            bbox_pred = bbox_preds[lvl]
            angle_pred = angle_preds[lvl]
            bbox_pred = bbox_pred.permute(0, 2, 3, 1)
            bbox_pred = bbox_pred.reshape(num_imgs, -1, 4)
            angle_pred = angle_pred.permute(0, 2, 3, 1)
            angle_pred = angle_pred.reshape(num_imgs, -1, self.coding_len)
            angle_pred = self.angle_coder.decode(angle_pred)
            bbox_pred = torch.cat([bbox_pred, angle_pred], dim=-1)
            points = mlvl_points[lvl]
            for img_id in range(num_imgs):
                bbox_pred_i = bbox_pred[img_id]
                decode_bbox_i = self.bbox_coder.decode(points, bbox_pred_i)
                bboxes_list[img_id].append(decode_bbox_i.detach())

        return bboxes_list
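# ---------------------------------------------------------------------------
# Editorial usage sketch, not part of the original module: a minimal smoke
# test of the head with its default configuration. The batch size, class
# count (15, as in DOTA) and the 512x512 input size are illustrative
# assumptions, and running this requires a working mmrotate installation.
if __name__ == '__main__':
    head = CSLRFCOSHead(num_classes=15, in_channels=256)
    # One feature map per default stride (4, 8, 16, 32, 64), as an FPN
    # would produce for a 512x512 input image.
    feats = tuple(
        torch.rand(2, 256, 512 // s, 512 // s) for s in head.strides)
    cls_scores, bbox_preds, angle_preds, centernesses = head(feats)
    # With the default CSLCoder ('le90', omega=1), each angle_pred carries
    # head.coding_len (180) channels rather than a single regressed angle.
    assert angle_preds[0].shape[1] == head.coding_len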