
Source code for mmrotate.models.roi_heads.rotate_standard_roi_head

# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta

import torch
from mmcv.runner import BaseModule
from mmdet.core import bbox2roi

from mmrotate.core import build_assigner, build_sampler, obb2xyxy, rbbox2result
from ..builder import (ROTATED_HEADS, build_head, build_roi_extractor,
                       build_shared_head)


@ROTATED_HEADS.register_module()
class RotatedStandardRoIHead(BaseModule, metaclass=ABCMeta):
    """Simplest base rotated roi head including one bbox head.

    Args:
        bbox_roi_extractor (dict, optional): Config of ``bbox_roi_extractor``.
        bbox_head (dict, optional): Config of ``bbox_head``.
        shared_head (dict, optional): Config of ``shared_head``.
        train_cfg (dict, optional): Config of train.
        test_cfg (dict, optional): Config of test.
        pretrained (str, optional): Path of pretrained weight.
        init_cfg (dict, optional): Config of initialization.
        version (str, optional): Angle representations. Defaults to 'oc'.
    """

    def __init__(self,
                 bbox_roi_extractor=None,
                 bbox_head=None,
                 shared_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None,
                 init_cfg=None,
                 version='oc'):
        super(RotatedStandardRoIHead, self).__init__(init_cfg)
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.version = version

        if shared_head is not None:
            shared_head.pretrained = pretrained
            self.shared_head = build_shared_head(shared_head)

        if bbox_head is not None:
            self.init_bbox_head(bbox_roi_extractor, bbox_head)

        self.init_assigner_sampler()

        self.with_bbox = True if bbox_head is not None else False
        self.with_shared_head = True if shared_head is not None else False
    def init_assigner_sampler(self):
        """Initialize assigner and sampler."""
        self.bbox_assigner = None
        self.bbox_sampler = None
        if self.train_cfg:
            self.bbox_assigner = build_assigner(self.train_cfg.assigner)
            self.bbox_sampler = build_sampler(
                self.train_cfg.sampler, context=self)
    def init_bbox_head(self, bbox_roi_extractor, bbox_head):
        """Initialize ``bbox_head``.

        Args:
            bbox_roi_extractor (dict): Config of ``bbox_roi_extractor``.
            bbox_head (dict): Config of ``bbox_head``.
        """
        self.bbox_roi_extractor = build_roi_extractor(bbox_roi_extractor)
        self.bbox_head = build_head(bbox_head)
    def forward_dummy(self, x, proposals):
        """Dummy forward function.

        Args:
            x (list[Tensor]): list of multi-level img features.
            proposals (Tensor): region proposals of one image.

        Returns:
            tuple: classification scores and bbox predictions of the
                bbox head.
        """
        outs = ()
        rois = bbox2roi([proposals])
        if self.with_bbox:
            bbox_results = self._bbox_forward(x, rois)
            outs = outs + (bbox_results['cls_score'],
                           bbox_results['bbox_pred'])
        return outs
    def forward_train(self,
                      x,
                      img_metas,
                      proposal_list,
                      gt_bboxes,
                      gt_labels,
                      gt_bboxes_ignore=None,
                      gt_masks=None):
        """
        Args:
            x (list[Tensor]): list of multi-level img features.
            img_metas (list[dict]): list of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. For
                details on the values of these keys see
                `mmdet/datasets/pipelines/formatting.py:Collect`.
            proposal_list (list[Tensor]): list of region proposals.
            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
                shape (num_gts, 5) in [cx, cy, w, h, a] format.
            gt_labels (list[Tensor]): class indices corresponding to each box.
            gt_bboxes_ignore (None | list[Tensor]): specify which bounding
                boxes can be ignored when computing the loss.
            gt_masks (None | Tensor): true segmentation masks for each box
                used if the architecture supports a segmentation task.
                Always set to None.

        Returns:
            dict[str, Tensor]: a dictionary of loss components.
        """
        # assign gts and sample proposals
        if self.with_bbox:
            num_imgs = len(img_metas)
            if gt_bboxes_ignore is None:
                gt_bboxes_ignore = [None for _ in range(num_imgs)]
            sampling_results = []
            for i in range(num_imgs):
                gt_hbboxes = obb2xyxy(gt_bboxes[i], self.version)
                assign_result = self.bbox_assigner.assign(
                    proposal_list[i], gt_hbboxes, gt_bboxes_ignore[i],
                    gt_labels[i])
                sampling_result = self.bbox_sampler.sample(
                    assign_result,
                    proposal_list[i],
                    gt_hbboxes,
                    gt_labels[i],
                    feats=[lvl_feat[i][None] for lvl_feat in x])

                if gt_bboxes[i].numel() == 0:
                    sampling_result.pos_gt_bboxes = gt_bboxes[i].new(
                        (0, gt_bboxes[0].size(-1))).zero_()
                else:
                    sampling_result.pos_gt_bboxes = \
                        gt_bboxes[i][sampling_result.pos_assigned_gt_inds, :]

                sampling_results.append(sampling_result)

        losses = dict()
        # bbox head forward and loss
        if self.with_bbox:
            bbox_results = self._bbox_forward_train(x, sampling_results,
                                                    gt_bboxes, gt_labels,
                                                    img_metas)
            losses.update(bbox_results['loss_bbox'])

        return losses
    def _bbox_forward(self, x, rois):
        """Box head forward function used in both training and testing.

        Args:
            x (list[Tensor]): list of multi-level img features.
            rois (Tensor): region of interests with shape (n, 5) in
                (batch_ind, x1, y1, x2, y2) format.

        Returns:
            dict[str, Tensor]: a dictionary of bbox_results.
        """
        bbox_feats = self.bbox_roi_extractor(
            x[:self.bbox_roi_extractor.num_inputs], rois)
        if self.with_shared_head:
            bbox_feats = self.shared_head(bbox_feats)
        cls_score, bbox_pred = self.bbox_head(bbox_feats)

        bbox_results = dict(
            cls_score=cls_score, bbox_pred=bbox_pred, bbox_feats=bbox_feats)
        return bbox_results

    def _bbox_forward_train(self, x, sampling_results, gt_bboxes, gt_labels,
                            img_metas):
        """Run forward function and calculate loss for box head in training.

        Args:
            x (list[Tensor]): list of multi-level img features.
            sampling_results (list[Tensor]): list of sampling results.
            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
                shape (num_gts, 5) in [cx, cy, w, h, a] format.
            gt_labels (list[Tensor]): class indices corresponding to each box.
            img_metas (list[dict]): list of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.

        Returns:
            dict[str, Tensor]: a dictionary of bbox_results.
        """
        rois = bbox2roi([res.bboxes for res in sampling_results])
        bbox_results = self._bbox_forward(x, rois)

        bbox_targets = self.bbox_head.get_targets(sampling_results, gt_bboxes,
                                                  gt_labels, self.train_cfg)
        loss_bbox = self.bbox_head.loss(bbox_results['cls_score'],
                                        bbox_results['bbox_pred'], rois,
                                        *bbox_targets)

        bbox_results.update(loss_bbox=loss_bbox)
        return bbox_results
    async def async_simple_test(self,
                                x,
                                proposal_list,
                                img_metas,
                                rescale=False):
        """Async test without augmentation.

        Args:
            x (list[Tensor]): list of multi-level img features.
            proposal_list (list[Tensor]): list of region proposals.
            img_metas (list[dict]): list of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
            rescale (bool): If True, return boxes in original image space.
                Default: False.

        Returns:
            list[np.ndarray]: bbox results of one image, formatted as one
                array per class.
        """
        assert self.with_bbox, 'Bbox head must be implemented.'

        det_bboxes, det_labels = await self.async_test_bboxes(
            x, img_metas, proposal_list, self.test_cfg, rescale=rescale)
        bbox_results = rbbox2result(det_bboxes, det_labels,
                                    self.bbox_head.num_classes)
        return bbox_results
    def simple_test(self, x, proposal_list, img_metas, rescale=False):
        """Test without augmentation.

        Args:
            x (list[Tensor]): list of multi-level img features.
            proposal_list (list[Tensor]): list of region proposals.
            img_metas (list[dict]): list of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
            rescale (bool): If True, return boxes in original image space.
                Default: False.

        Returns:
            list[list[np.ndarray]]: bbox results of each image, where the
                outer list indexes images and the inner list indexes classes.
        """
        assert self.with_bbox, 'Bbox head must be implemented.'

        det_bboxes, det_labels = self.simple_test_bboxes(
            x, img_metas, proposal_list, self.test_cfg, rescale=rescale)

        bbox_results = [
            rbbox2result(det_bboxes[i], det_labels[i],
                         self.bbox_head.num_classes)
            for i in range(len(det_bboxes))
        ]
        return bbox_results
    def aug_test(self, x, proposal_list, img_metas, rescale=False):
        """Test with augmentations."""
        raise NotImplementedError
    def simple_test_bboxes(self,
                           x,
                           img_metas,
                           proposals,
                           rcnn_test_cfg,
                           rescale=False):
        """Test only det bboxes without augmentation.

        Args:
            x (tuple[Tensor]): Feature maps of all scale levels.
            img_metas (list[dict]): Image meta info.
            proposals (list[Tensor]): Region proposals.
            rcnn_test_cfg (obj:`ConfigDict`): `test_cfg` of R-CNN.
            rescale (bool): If True, return boxes in original image space.
                Default: False.

        Returns:
            tuple[list[Tensor], list[Tensor]]: The first list contains the
                rotated boxes of the corresponding image in a batch; each
                tensor has shape (num_boxes, 6), where the last dimension
                represents (cx, cy, w, h, a, score). Each Tensor in the
                second list is the labels with shape (num_boxes, ). The
                length of both lists should be equal to batch_size.
        """
        rois = bbox2roi(proposals)

        if rois.shape[0] == 0:
            batch_size = len(proposals)
            det_bbox = rois.new_zeros(0, 5)
            det_label = rois.new_zeros((0, ), dtype=torch.long)
            if rcnn_test_cfg is None:
                det_bbox = det_bbox[:, :4]
                det_label = rois.new_zeros(
                    (0, self.bbox_head.fc_cls.out_features))
            # There is no proposal in the whole batch
            return [det_bbox] * batch_size, [det_label] * batch_size

        bbox_results = self._bbox_forward(x, rois)
        img_shapes = tuple(meta['img_shape'] for meta in img_metas)
        scale_factors = tuple(meta['scale_factor'] for meta in img_metas)

        # split batch bbox prediction back to each image
        cls_score = bbox_results['cls_score']
        bbox_pred = bbox_results['bbox_pred']
        num_proposals_per_img = tuple(len(p) for p in proposals)
        rois = rois.split(num_proposals_per_img, 0)
        cls_score = cls_score.split(num_proposals_per_img, 0)

        # some detector with_reg is False, bbox_pred will be None
        if bbox_pred is not None:
            # TODO move this to a sabl_roi_head
            # the bbox prediction of some detectors like SABL is not Tensor
            if isinstance(bbox_pred, torch.Tensor):
                bbox_pred = bbox_pred.split(num_proposals_per_img, 0)
            else:
                bbox_pred = self.bbox_head.bbox_pred_split(
                    bbox_pred, num_proposals_per_img)
        else:
            bbox_pred = (None, ) * len(proposals)

        # apply bbox post-processing to each image individually
        det_bboxes = []
        det_labels = []
        for i in range(len(proposals)):
            if rois[i].shape[0] == 0:
                # There is no proposal in the single image
                det_bbox = rois[i].new_zeros(0, 5)
                det_label = rois[i].new_zeros((0, ), dtype=torch.long)
                if rcnn_test_cfg is None:
                    det_bbox = det_bbox[:, :4]
                    det_label = rois[i].new_zeros(
                        (0, self.bbox_head.fc_cls.out_features))
            else:
                det_bbox, det_label = self.bbox_head.get_bboxes(
                    rois[i],
                    cls_score[i],
                    bbox_pred[i],
                    img_shapes[i],
                    scale_factors[i],
                    rescale=rescale,
                    cfg=rcnn_test_cfg)
            det_bboxes.append(det_bbox)
            det_labels.append(det_label)
        return det_bboxes, det_labels
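
The training path above assigns horizontal proposals against horizontal versions of the rotated ground-truth boxes, which is why forward_train calls obb2xyxy before assignment and sampling. Below is a minimal sketch of that conversion; the box values are illustrative, not taken from the library.

import torch
from mmrotate.core import obb2xyxy

# One rotated ground-truth box in (cx, cy, w, h, a) format, as expected by
# forward_train; 'oc' is the default angle version of this head.
gt_bboxes = torch.tensor([[50., 50., 20., 10., 0.5]])
gt_hbboxes = obb2xyxy(gt_bboxes, 'oc')  # horizontal enclosing box (x1, y1, x2, y2)
print(gt_hbboxes.shape)  # torch.Size([1, 4])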
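
Both _bbox_forward_train and simple_test_bboxes pack per-image proposals into a single RoI tensor with bbox2roi, which prepends the image index to each box so a batch can be processed at once. A small sketch with made-up proposal counts:

import torch
from mmdet.core import bbox2roi

proposals_img0 = torch.rand(4, 4) * 100  # (x1, y1, x2, y2) boxes for image 0
proposals_img1 = torch.rand(2, 4) * 100  # (x1, y1, x2, y2) boxes for image 1
rois = bbox2roi([proposals_img0, proposals_img1])
print(rois.shape)  # torch.Size([6, 5]); column 0 is the image index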
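
At test time, simple_test converts per-image detections into the per-class list format used throughout mmrotate via rbbox2result. A sketch with a single hand-made detection; the class count of 15 is only an example value.

import torch
from mmrotate.core import rbbox2result

# One detection in (cx, cy, w, h, a, score) format with its class label.
det_bboxes = torch.tensor([[50., 50., 20., 10., 0.3, 0.9]])
det_labels = torch.tensor([2])
results = rbbox2result(det_bboxes, det_labels, num_classes=15)
print(len(results), results[2].shape)  # 15 (1, 6); one array per class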