Shortcuts

Source code for mmrotate.models.detectors.s2anet

# Copyright (c) OpenMMLab. All rights reserved.
from mmrotate.core import rbbox2result
from ..builder import ROTATED_DETECTORS, build_backbone, build_head, build_neck
from .base import RotatedBaseDetector
from .utils import AlignConvModule


@ROTATED_DETECTORS.register_module()
class S2ANet(RotatedBaseDetector):
    """Implementation of `Align Deep Features for Oriented Object Detection.`__

    __ https://ieeexplore.ieee.org/document/9377550

    The detector runs a backbone (plus an optional neck), a first head
    (``fam_head``) whose outputs are refined into rois, an alignment module
    driven by those rois, and a second head (``odm_head``) that consumes the
    aligned features.

    Args:
        backbone: Config handed to ``build_backbone``. ``pretrained`` is
            assigned onto it as an attribute, so it must support attribute
            assignment.
        neck: Optional config handed to ``build_neck``. No neck is built
            when it is None.
        fam_head: Config handed to ``build_head`` for the first head.
            Required in practice even though the default is None.
        align_cfgs: Mapping providing the keys ``'type'``, ``'kernel_size'``,
            ``'channels'`` and ``'featmap_strides'``. Required in practice
            even though the default is None.
        odm_head: Config handed to ``build_head`` for the second head.
            Required in practice even though the default is None.
        train_cfg: Optional training config. When given, its ``'fam_cfg'``
            and ``'odm_cfg'`` entries are passed to the respective heads.
        test_cfg: Testing config, passed to both heads and used by
            ``simple_test``.
        pretrained: Value stored on the backbone config as ``pretrained``.
    """

    def __init__(self,
                 backbone,
                 neck=None,
                 fam_head=None,
                 align_cfgs=None,
                 odm_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        super().__init__()

        backbone.pretrained = pretrained
        self.backbone = build_backbone(backbone)
        if neck is not None:
            self.neck = build_neck(neck)

        # Update a shallow copy so the caller's head config is not mutated;
        # otherwise a stale ``train_cfg`` entry would survive when the same
        # config is reused later with ``train_cfg=None``.
        fam_head_ = fam_head.copy()
        if train_cfg is not None:
            fam_head_.update(train_cfg=train_cfg['fam_cfg'])
        fam_head_.update(test_cfg=test_cfg)
        self.fam_head = build_head(fam_head_)

        self.align_conv_type = align_cfgs['type']
        self.align_conv_size = align_cfgs['kernel_size']
        self.feat_channels = align_cfgs['channels']
        self.featmap_strides = align_cfgs['featmap_strides']

        # NOTE: ``self.align_conv`` is only created for the 'AlignConv' type;
        # with any other type the forward methods below have no alignment
        # module to call.
        if self.align_conv_type == 'AlignConv':
            self.align_conv = AlignConvModule(self.feat_channels,
                                              self.featmap_strides,
                                              self.align_conv_size)

        # Same copy-before-update treatment as for ``fam_head``.
        odm_head_ = odm_head.copy()
        if train_cfg is not None:
            odm_head_.update(train_cfg=train_cfg['odm_cfg'])
        odm_head_.update(test_cfg=test_cfg)
        self.odm_head = build_head(odm_head_)

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

    def _refine_and_align(self, x, fam_outs):
        """Refine ``fam_head`` outputs into rois and run ``odm_head``.

        Shared by ``forward_dummy``, ``forward_train`` and ``simple_test``.

        Args:
            x: Features returned by ``extract_feat``.
            fam_outs: Outputs of ``self.fam_head(x)``.

        Returns:
            tuple: ``(rois, odm_outs)`` where ``rois`` is the result of
            ``fam_head.refine_bboxes`` and ``odm_outs`` is the output of
            ``odm_head`` on the aligned features.
        """
        rois = self.fam_head.refine_bboxes(*fam_outs)
        # rois: list(indexed by images) of list(indexed by levels)
        align_feat = self.align_conv(x, rois)
        odm_outs = self.odm_head(align_feat)
        return rois, odm_outs

    def extract_feat(self, img):
        """Directly extract features from the backbone+neck."""
        x = self.backbone(img)
        if self.with_neck:
            x = self.neck(x)
        return x

    def forward_dummy(self, img):
        """Used for computing network flops.

        See `mmdetection/tools/get_flops.py`

        Returns the raw outputs of ``odm_head``.
        """
        x = self.extract_feat(img)
        fam_outs = self.fam_head(x)
        _, odm_outs = self._refine_and_align(x, fam_outs)
        return odm_outs

    def forward_train(self,
                      img,
                      img_metas,
                      gt_bboxes,
                      gt_labels,
                      gt_bboxes_ignore=None):
        """Forward function of S2ANet.

        Args:
            img: Input passed to ``extract_feat``.
            img_metas: Image information, forwarded to both heads' ``loss``.
            gt_bboxes: Ground-truth boxes, forwarded to both heads' ``loss``.
            gt_labels: Ground-truth labels, forwarded to both heads'
                ``loss``.
            gt_bboxes_ignore: Optional boxes to ignore, forwarded to both
                heads' ``loss``. Defaults to None.

        Returns:
            dict: Losses of ``fam_head`` under keys prefixed with ``fam.``
            and losses of ``odm_head`` under keys prefixed with ``odm.``.
        """
        losses = {}
        x = self.extract_feat(img)

        fam_outs = self.fam_head(x)
        fam_losses = self.fam_head.loss(
            *fam_outs,
            gt_bboxes,
            gt_labels,
            img_metas,
            gt_bboxes_ignore=gt_bboxes_ignore)
        losses.update(
            {f'fam.{name}': value
             for name, value in fam_losses.items()})

        rois, odm_outs = self._refine_and_align(x, fam_outs)
        # The second head's loss additionally receives the refined rois.
        odm_losses = self.odm_head.loss(
            *odm_outs,
            gt_bboxes,
            gt_labels,
            img_metas,
            gt_bboxes_ignore=gt_bboxes_ignore,
            rois=rois)
        losses.update(
            {f'odm.{name}': value
             for name, value in odm_losses.items()})

        return losses

    def simple_test(self, img, img_meta, rescale=False):
        """Test function without test time augmentation.

        Args:
            img: Input passed to ``extract_feat``.
            img_meta (list[dict]): List of image information.
            rescale (bool, optional): Whether to rescale the results.
                Defaults to False.

        Returns:
            list[list[np.ndarray]]: BBox results of each image and classes. \
                The outer list corresponds to each image. The inner list \
                corresponds to each class.
        """
        x = self.extract_feat(img)
        fam_outs = self.fam_head(x)
        rois, odm_outs = self._refine_and_align(x, fam_outs)

        bbox_list = self.odm_head.get_bboxes(
            *odm_outs, img_meta, self.test_cfg, rescale, rois=rois)
        return [
            rbbox2result(det_bboxes, det_labels, self.odm_head.num_classes)
            for det_bboxes, det_labels in bbox_list
        ]

    def aug_test(self, imgs, img_metas, **kwargs):
        """Test function with test time augmentation."""
        raise NotImplementedError
Read the Docs v: v0.2.0
Versions
latest
stable
v0.2.0
v0.1.1
v0.1.0
main
dev
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.