Shortcuts

Source code for mmrotate.models.detectors.single_stage

# Copyright (c) OpenMMLab. All rights reserved.
import warnings

from mmrotate.core import rbbox2result
from ..builder import ROTATED_DETECTORS, build_backbone, build_head, build_neck
from .base import RotatedBaseDetector


@ROTATED_DETECTORS.register_module()
class RotatedSingleStageDetector(RotatedBaseDetector):
    """Base class for rotated single-stage detectors.

    Single-stage detectors directly and densely predict bounding boxes on the
    output features of the backbone+neck.

    Args:
        backbone: Config passed to ``build_backbone``.
        neck: Optional config passed to ``build_neck``. When ``None`` no
            ``neck`` attribute is set on the detector.
        bbox_head: Config passed to ``build_head``. Despite the ``None``
            default it is effectively required: the constructor calls
            ``bbox_head.update(...)`` unconditionally, so ``None`` fails with
            ``AttributeError``.
        train_cfg: Training config; injected into ``bbox_head`` and stored on
            the detector.
        test_cfg: Testing config; injected into ``bbox_head`` and stored on
            the detector.
        pretrained: Deprecated. If truthy, a warning is emitted and the value
            is assigned to ``backbone.pretrained``.
        init_cfg: Forwarded to the parent class constructor.
    """

    def __init__(self,
                 backbone,
                 neck=None,
                 bbox_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None,
                 init_cfg=None):
        super().__init__(init_cfg)
        if pretrained:
            warnings.warn('DeprecationWarning: pretrained is deprecated, '
                          'please use "init_cfg" instead')
            backbone.pretrained = pretrained
        self.backbone = build_backbone(backbone)
        if neck is not None:
            self.neck = build_neck(neck)
        # The head receives the train/test configs through its own config, so
        # they are written into ``bbox_head`` (in place) before building it.
        bbox_head.update(train_cfg=train_cfg, test_cfg=test_cfg)
        self.bbox_head = build_head(bbox_head)
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

    def extract_feat(self, img):
        """Directly extract features from the backbone+neck.

        Args:
            img: Input passed to ``self.backbone``.

        Returns:
            The backbone output, additionally passed through ``self.neck``
            when ``self.with_neck`` is truthy.
        """
        x = self.backbone(img)
        if self.with_neck:
            x = self.neck(x)
        return x

    def forward_dummy(self, img):
        """Used for computing network flops.

        See `mmdetection/tools/analysis_tools/get_flops.py`

        Args:
            img: Input passed to :meth:`extract_feat`.

        Returns:
            The raw outputs of ``self.bbox_head`` applied to the extracted
            features.
        """
        x = self.extract_feat(img)
        outs = self.bbox_head(x)
        return outs

    def forward_train(self,
                      img,
                      img_metas,
                      gt_bboxes,
                      gt_labels,
                      gt_bboxes_ignore=None):
        """Run the training forward pass and return the head's losses.

        Args:
            img (Tensor): Input images of shape (N, C, H, W).
                Typically these should be mean centered and std scaled.
            img_metas (list[dict]): A List of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                :class:`mmdet.datasets.pipelines.Collect`.
            gt_bboxes (list[Tensor]): Ground-truth boxes for each image,
                forwarded unchanged to ``bbox_head.forward_train``.
                NOTE(review): the original docstring described these as
                ``[tl_x, tl_y, br_x, br_y]``; confirm the box encoding the
                rotated head actually expects.
            gt_labels (list[Tensor]): Class indices corresponding to each box.
            gt_bboxes_ignore (None | list[Tensor]): Specify which bounding
                boxes can be ignored when computing the loss.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        # The parent hook runs first; its return value is not used here.
        super().forward_train(img, img_metas)
        x = self.extract_feat(img)
        losses = self.bbox_head.forward_train(x, img_metas, gt_bboxes,
                                              gt_labels, gt_bboxes_ignore)
        return losses

    def simple_test(self, img, img_metas, rescale=False):
        """Test function without test time augmentation.

        Args:
            img: Input images, passed directly to :meth:`extract_feat`.
            img_metas (list[dict]): List of image information.
            rescale (bool, optional): Whether to rescale the results.
                Defaults to False.

        Returns:
            list[list[np.ndarray]]: BBox results of each image and classes.
                The outer list corresponds to each image. The inner list
                corresponds to each class.
        """
        x = self.extract_feat(img)
        outs = self.bbox_head(x)
        bbox_list = self.bbox_head.get_bboxes(
            *outs, img_metas, rescale=rescale)

        # One (det_bboxes, det_labels) pair per entry of ``bbox_list``.
        bbox_results = [
            rbbox2result(det_bboxes, det_labels, self.bbox_head.num_classes)
            for det_bboxes, det_labels in bbox_list
        ]
        return bbox_results

    def aug_test(self, imgs, img_metas, rescale=False):
        """Test function with test time augmentation.

        Args:
            imgs (list[Tensor]): the outer list indicates test-time
                augmentations and inner Tensor should have a shape NxCxHxW,
                which contains all images in the batch.
            img_metas (list[list[dict]]): the outer list indicates test-time
                augs (multiscale, flip, etc.) and the inner list indicates
                images in a batch. each dict has image information.
            rescale (bool, optional): Whether to rescale the results.
                Defaults to False.

        Returns:
            list[list[np.ndarray]]: BBox results of each image and classes.
                The outer list corresponds to each image. The inner list
                corresponds to each class.

        Raises:
            AssertionError: If ``self.bbox_head`` has no ``aug_test``
                attribute.
        """
        assert hasattr(self.bbox_head, 'aug_test'), \
            f'{self.bbox_head.__class__.__name__}' \
            ' does not support test-time augmentation'

        # ``extract_feats`` is not defined in this class; presumably it is
        # provided by the base detector - TODO confirm.
        feats = self.extract_feats(imgs)
        results_list = self.bbox_head.aug_test(
            feats, img_metas, rescale=rescale)
        bbox_results = [
            rbbox2result(det_bboxes, det_labels, self.bbox_head.num_classes)
            for det_bboxes, det_labels in results_list
        ]
        return bbox_results
Read the Docs v: stable
Versions
latest
stable
1.x
v1.0.0rc0
v0.3.4
v0.3.3
v0.3.2
v0.3.1
v0.3.0
v0.2.0
v0.1.1
v0.1.0
main
dev
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.