Shortcuts

Source code for mmrotate.models.roi_heads.rotate_standard_roi_head

# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta

import torch
from mmcv.runner import BaseModule
from mmdet.core import bbox2roi

from mmrotate.core import build_assigner, build_sampler, obb2xyxy, rbbox2result
from ..builder import (ROTATED_HEADS, build_head, build_roi_extractor,
                       build_shared_head)


@ROTATED_HEADS.register_module()
class RotatedStandardRoIHead(BaseModule, metaclass=ABCMeta):
    """Simplest base rotated roi head including one bbox head.

    During training, each image's rotated ground-truth boxes are converted
    with ``obb2xyxy`` before assignment and sampling against the proposals.
    The sampled positives' ``pos_gt_bboxes`` are then replaced with the
    original rotated ground-truth boxes before targets are computed.

    Args:
        bbox_roi_extractor (dict, optional): Config of ``bbox_roi_extractor``.
        bbox_head (dict, optional): Config of ``bbox_head``.
        shared_head (dict, optional): Config of ``shared_head``.
        train_cfg (dict, optional): Config of train.
        test_cfg (dict, optional): Config of test.
        pretrained (str, optional): Path of pretrained weight.
        init_cfg (dict, optional): Config of initialization.
        version (str, optional): Angle representations. Defaults to 'oc'.
    """

    def __init__(self,
                 bbox_roi_extractor=None,
                 bbox_head=None,
                 shared_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None,
                 init_cfg=None,
                 version='oc'):
        super(RotatedStandardRoIHead, self).__init__(init_cfg)
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.version = version
        if shared_head is not None:
            shared_head.pretrained = pretrained
            self.shared_head = build_shared_head(shared_head)

        if bbox_head is not None:
            self.init_bbox_head(bbox_roi_extractor, bbox_head)

        self.init_assigner_sampler()

        self.with_bbox = bbox_head is not None
        self.with_shared_head = shared_head is not None

    def init_assigner_sampler(self):
        """Initialize assigner and sampler.

        Both stay ``None`` when no ``train_cfg`` is given (e.g. test-only
        models).
        """
        self.bbox_assigner = None
        self.bbox_sampler = None
        if self.train_cfg:
            self.bbox_assigner = build_assigner(self.train_cfg.assigner)
            self.bbox_sampler = build_sampler(
                self.train_cfg.sampler, context=self)

    def init_bbox_head(self, bbox_roi_extractor, bbox_head):
        """Initialize ``bbox_head``.

        Args:
            bbox_roi_extractor (dict): Config of ``bbox_roi_extractor``.
            bbox_head (dict): Config of ``bbox_head``.
        """
        self.bbox_roi_extractor = build_roi_extractor(bbox_roi_extractor)
        self.bbox_head = build_head(bbox_head)

    def forward_dummy(self, x, proposals):
        """Dummy forward function.

        Args:
            x (list[Tensor]): list of multi-level img features.
            proposals (Tensor): region proposals of a single image; it is
                wrapped into a one-element list before ``bbox2roi``.

        Returns:
            tuple[Tensor]: ``(cls_score, bbox_pred)`` when the head has a
            bbox branch, otherwise an empty tuple.
        """
        outs = ()
        rois = bbox2roi([proposals])
        if self.with_bbox:
            bbox_results = self._bbox_forward(x, rois)
            outs = outs + (bbox_results['cls_score'],
                           bbox_results['bbox_pred'])
        return outs

    def forward_train(self,
                      x,
                      img_metas,
                      proposal_list,
                      gt_bboxes,
                      gt_labels,
                      gt_bboxes_ignore=None,
                      gt_masks=None):
        """Assign/sample proposals and compute the bbox-head losses.

        Args:
            x (list[Tensor]): list of multi-level img features.
            img_metas (list[dict]): list of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `mmdet/datasets/pipelines/formatting.py:Collect`.
            proposal_list (list[Tensor]): list of region proposals, one
                tensor per image.
            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
                shape (num_gts, 5) in [cx, cy, w, h, a] format.
            gt_labels (list[Tensor]): class indices corresponding to each box
            gt_bboxes_ignore (None | list[Tensor]): specify which bounding
                boxes can be ignored when computing the loss.
            gt_masks (None | Tensor): Unused by this head; kept for
                interface compatibility.

        Returns:
            dict[str, Tensor]: a dictionary of loss components.
        """
        # assign gts and sample proposals
        if self.with_bbox:
            num_imgs = len(img_metas)
            if gt_bboxes_ignore is None:
                gt_bboxes_ignore = [None for _ in range(num_imgs)]
            sampling_results = []
            for i in range(num_imgs):
                # Assignment and sampling are done against the xyxy form of
                # the rotated gts.
                gt_hbboxes = obb2xyxy(gt_bboxes[i], self.version)
                assign_result = self.bbox_assigner.assign(
                    proposal_list[i], gt_hbboxes, gt_bboxes_ignore[i],
                    gt_labels[i])
                sampling_result = self.bbox_sampler.sample(
                    assign_result,
                    proposal_list[i],
                    gt_hbboxes,
                    gt_labels[i],
                    feats=[lvl_feat[i][None] for lvl_feat in x])

                # Swap the sampler's xyxy gts for the original rotated gts,
                # which are what the bbox head builds its targets from.
                if gt_bboxes[i].numel() == 0:
                    # Empty 2-D placeholder of shape (0, box_dim). Note that
                    # the legacy ``Tensor.new((0, d))`` would NOT do this: a
                    # plain tuple is read as data and yields a 1-D tensor.
                    sampling_result.pos_gt_bboxes = gt_bboxes[i].new_zeros(
                        (0, gt_bboxes[i].size(-1)))
                else:
                    sampling_result.pos_gt_bboxes = \
                        gt_bboxes[i][sampling_result.pos_assigned_gt_inds, :]

                sampling_results.append(sampling_result)

        losses = dict()
        # bbox head forward and loss
        if self.with_bbox:
            bbox_results = self._bbox_forward_train(x, sampling_results,
                                                    gt_bboxes, gt_labels,
                                                    img_metas)
            losses.update(bbox_results['loss_bbox'])

        return losses

    def _bbox_forward(self, x, rois):
        """Box head forward function used in both training and testing.

        Args:
            x (list[Tensor]): list of multi-level img features.
            rois (Tensor): RoIs of the whole batch as produced by
                ``bbox2roi``.

        Returns:
            dict[str, Tensor]: ``cls_score``, ``bbox_pred`` and the pooled
            ``bbox_feats``.
        """
        bbox_feats = self.bbox_roi_extractor(
            x[:self.bbox_roi_extractor.num_inputs], rois)
        if self.with_shared_head:
            bbox_feats = self.shared_head(bbox_feats)
        cls_score, bbox_pred = self.bbox_head(bbox_feats)

        bbox_results = dict(
            cls_score=cls_score, bbox_pred=bbox_pred, bbox_feats=bbox_feats)
        return bbox_results

    def _bbox_forward_train(self, x, sampling_results, gt_bboxes, gt_labels,
                            img_metas):
        """Run forward function and calculate loss for box head in training.

        Args:
            x (list[Tensor]): list of multi-level img features.
            sampling_results (list): per-image sampling results returned by
                ``self.bbox_sampler.sample`` (their ``bboxes`` attribute is
                turned into RoIs).
            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
                shape (num_gts, 5) in [cx, cy, w, h, a] format.
            gt_labels (list[Tensor]): class indices corresponding to each box
            img_metas (list[dict]): list of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.

        Returns:
            dict[str, Tensor]: the outputs of ``_bbox_forward`` plus a
            ``loss_bbox`` entry holding the bbox-head loss dict.
        """
        rois = bbox2roi([res.bboxes for res in sampling_results])
        bbox_results = self._bbox_forward(x, rois)

        bbox_targets = self.bbox_head.get_targets(sampling_results, gt_bboxes,
                                                  gt_labels, self.train_cfg)
        loss_bbox = self.bbox_head.loss(bbox_results['cls_score'],
                                        bbox_results['bbox_pred'], rois,
                                        *bbox_targets)

        bbox_results.update(loss_bbox=loss_bbox)
        return bbox_results

    async def async_simple_test(self,
                                x,
                                proposal_list,
                                img_metas,
                                rescale=False):
        """Async test without augmentation.

        Args:
            x (list[Tensor]): list of multi-level img features.
            proposal_list (list[Tensor]): list of region proposals.
            img_metas (list[dict]): list of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
            rescale (bool): If True, return boxes in original image space.
                Default: False.

        Returns:
            The per-class detection results of a single image, as returned
            by ``rbbox2result``.
        """
        assert self.with_bbox, 'Bbox head must be implemented.'

        # ``async_test_bboxes`` is not defined on this class (it only
        # inherits ``BaseModule``); use it when a subclass/mixin provides it,
        # otherwise fall back to the synchronous path for the single image.
        async_test_bboxes = getattr(self, 'async_test_bboxes', None)
        if async_test_bboxes is not None:
            det_bboxes, det_labels = await async_test_bboxes(
                x, img_metas, proposal_list, self.test_cfg, rescale=rescale)
        else:
            det_bboxes, det_labels = self.simple_test_bboxes(
                x, img_metas, proposal_list, self.test_cfg, rescale=rescale)
            det_bboxes, det_labels = det_bboxes[0], det_labels[0]

        bbox_results = rbbox2result(det_bboxes, det_labels,
                                    self.bbox_head.num_classes)
        return bbox_results

    def simple_test(self, x, proposal_list, img_metas, rescale=False):
        """Test without augmentation.

        Args:
            x (list[Tensor]): list of multi-level img features.
            proposal_list (list[Tensor]): list of region proposals.
            img_metas (list[dict]): list of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
            rescale (bool): If True, return boxes in original image space.
                Default: False.

        Returns:
            list: one ``rbbox2result`` output (per-class results) per image.
        """
        assert self.with_bbox, 'Bbox head must be implemented.'

        det_bboxes, det_labels = self.simple_test_bboxes(
            x, img_metas, proposal_list, self.test_cfg, rescale=rescale)

        bbox_results = [
            rbbox2result(det_bboxes[i], det_labels[i],
                         self.bbox_head.num_classes)
            for i in range(len(det_bboxes))
        ]

        return bbox_results

    def aug_test(self, x, proposal_list, img_metas, rescale=False):
        """Test with augmentations (not supported)."""
        raise NotImplementedError

    def _empty_det_result(self, ref, rcnn_test_cfg):
        """Build empty ``(det_bbox, det_label)`` placeholders for one image.

        Args:
            ref (Tensor): Tensor used only as the factory (device/dtype) for
                the new empty tensors.
            rcnn_test_cfg (obj:`ConfigDict` | None): `test_cfg` of R-CNN.

        Returns:
            tuple[Tensor, Tensor]: With a config, a ``(0, 6)`` box tensor
            (5 rotated-box columns + score) and a ``(0, )`` long label
            tensor. Without one, a ``(0, 5)`` box tensor and a
            ``(0, fc_cls.out_features)`` score tensor.
        """
        # NOTE(review): 6 columns assumes ``bbox_head.get_bboxes`` returns
        # rotated detections as (cx, cy, w, h, a, score) — confirm against
        # the configured bbox head.
        det_bbox = ref.new_zeros(0, 6)
        det_label = ref.new_zeros((0, ), dtype=torch.long)
        if rcnn_test_cfg is None:
            det_bbox = det_bbox[:, :5]
            det_label = ref.new_zeros(
                (0, self.bbox_head.fc_cls.out_features))
        return det_bbox, det_label

    def simple_test_bboxes(self,
                           x,
                           img_metas,
                           proposals,
                           rcnn_test_cfg,
                           rescale=False):
        """Test only det bboxes without augmentation.

        Args:
            x (tuple[Tensor]): Feature maps of all scale level.
            img_metas (list[dict]): Image meta info.
            proposals (List[Tensor]): Region proposals.
            rcnn_test_cfg (obj:`ConfigDict`): `test_cfg` of R-CNN.
            rescale (bool): If True, return boxes in original image space.
                Default: False.

        Returns:
            tuple[list[Tensor], list[Tensor]]: The first list contains
            the detected boxes of each image in the batch, as returned by
            ``bbox_head.get_bboxes`` (presumably (num_boxes, 6) rows of
            (cx, cy, w, h, a, score) for rotated heads — confirm). Each
            tensor in the second list holds the labels with shape
            (num_boxes, ). Both lists have length batch_size.
        """
        rois = bbox2roi(proposals)

        if rois.shape[0] == 0:
            # There is no proposal in the whole batch
            batch_size = len(proposals)
            det_bbox, det_label = self._empty_det_result(rois, rcnn_test_cfg)
            return [det_bbox] * batch_size, [det_label] * batch_size

        bbox_results = self._bbox_forward(x, rois)
        img_shapes = tuple(meta['img_shape'] for meta in img_metas)
        scale_factors = tuple(meta['scale_factor'] for meta in img_metas)

        # split batch bbox prediction back to each image
        cls_score = bbox_results['cls_score']
        bbox_pred = bbox_results['bbox_pred']
        num_proposals_per_img = tuple(len(p) for p in proposals)
        rois = rois.split(num_proposals_per_img, 0)
        cls_score = cls_score.split(num_proposals_per_img, 0)

        # some detector with_reg is False, bbox_pred will be None
        if bbox_pred is not None:
            # TODO move this to a sabl_roi_head
            # the bbox prediction of some detectors like SABL is not Tensor
            if isinstance(bbox_pred, torch.Tensor):
                bbox_pred = bbox_pred.split(num_proposals_per_img, 0)
            else:
                bbox_pred = self.bbox_head.bbox_pred_split(
                    bbox_pred, num_proposals_per_img)
        else:
            bbox_pred = (None, ) * len(proposals)

        # apply bbox post-processing to each image individually
        det_bboxes = []
        det_labels = []
        for i in range(len(proposals)):
            if rois[i].shape[0] == 0:
                # There is no proposal in the single image
                det_bbox, det_label = self._empty_det_result(
                    rois[i], rcnn_test_cfg)
            else:
                det_bbox, det_label = self.bbox_head.get_bboxes(
                    rois[i],
                    cls_score[i],
                    bbox_pred[i],
                    img_shapes[i],
                    scale_factors[i],
                    rescale=rescale,
                    cfg=rcnn_test_cfg)
            det_bboxes.append(det_bbox)
            det_labels.append(det_label)
        return det_bboxes, det_labels
Read the Docs v: v0.3.4
Versions
latest
stable
1.x
v1.0.0rc0
v0.3.4
v0.3.3
v0.3.2
v0.3.1
v0.3.0
v0.2.0
v0.1.1
v0.1.0
main
dev
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.