
Source code for mmrotate.models.detectors.s2anet

# Copyright (c) OpenMMLab. All rights reserved.
from mmrotate.core import rbbox2result
from ..builder import ROTATED_DETECTORS, build_backbone, build_head, build_neck
from .base import RotatedBaseDetector
from .utils import AlignConvModule


@ROTATED_DETECTORS.register_module()
class S2ANet(RotatedBaseDetector):
    """Implementation of `Align Deep Features for Oriented Object
    Detection.`__

    __ https://ieeexplore.ieee.org/document/9377550
    """

    def __init__(self,
                 backbone,
                 neck=None,
                 fam_head=None,
                 align_cfgs=None,
                 odm_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        super(S2ANet, self).__init__()

        backbone.pretrained = pretrained
        self.backbone = build_backbone(backbone)
        if neck is not None:
            self.neck = build_neck(neck)

        if train_cfg is not None:
            fam_head.update(train_cfg=train_cfg['fam_cfg'])
        fam_head.update(test_cfg=test_cfg)
        self.fam_head = build_head(fam_head)

        self.align_conv_type = align_cfgs['type']
        self.align_conv_size = align_cfgs['kernel_size']
        self.feat_channels = align_cfgs['channels']
        self.featmap_strides = align_cfgs['featmap_strides']

        if self.align_conv_type == 'AlignConv':
            self.align_conv = AlignConvModule(self.feat_channels,
                                              self.featmap_strides,
                                              self.align_conv_size)

        if train_cfg is not None:
            odm_head.update(train_cfg=train_cfg['odm_cfg'])
        odm_head.update(test_cfg=test_cfg)
        self.odm_head = build_head(odm_head)

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
    def extract_feat(self, img):
        """Directly extract features from the backbone+neck."""
        x = self.backbone(img)
        if self.with_neck:
            x = self.neck(x)
        return x
    def forward_dummy(self, img):
        """Used for computing network flops.

        See `mmdetection/tools/get_flops.py` (a usage sketch follows this
        class).
        """
        x = self.extract_feat(img)
        outs = self.fam_head(x)
        rois = self.fam_head.refine_bboxes(*outs)
        # rois: list(indexed by images) of list(indexed by levels)
        align_feat = self.align_conv(x, rois)
        outs = self.odm_head(align_feat)
        return outs
    def forward_train(self,
                      img,
                      img_metas,
                      gt_bboxes,
                      gt_labels,
                      gt_bboxes_ignore=None):
        """Forward function of S2ANet."""
        losses = dict()
        x = self.extract_feat(img)

        outs = self.fam_head(x)
        loss_inputs = outs + (gt_bboxes, gt_labels, img_metas)
        loss_base = self.fam_head.loss(
            *loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
        for name, value in loss_base.items():
            losses[f'fam.{name}'] = value

        rois = self.fam_head.refine_bboxes(*outs)
        # rois: list(indexed by images) of list(indexed by levels)
        align_feat = self.align_conv(x, rois)
        outs = self.odm_head(align_feat)
        loss_inputs = outs + (gt_bboxes, gt_labels, img_metas)
        loss_refine = self.odm_head.loss(
            *loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore, rois=rois)
        for name, value in loss_refine.items():
            losses[f'odm.{name}'] = value

        return losses
    def simple_test(self, img, img_meta, rescale=False):
        """Test function without test time augmentation.

        Args:
            img (torch.Tensor): Input image.
            img_meta (list[dict]): List of image information.
            rescale (bool, optional): Whether to rescale the results.
                Defaults to False.

        Returns:
            list[list[np.ndarray]]: BBox results of each image and classes.
                The outer list corresponds to each image. The inner list
                corresponds to each class.
        """
        x = self.extract_feat(img)
        outs = self.fam_head(x)
        rois = self.fam_head.refine_bboxes(*outs)
        # rois: list(indexed by images) of list(indexed by levels)
        align_feat = self.align_conv(x, rois)
        outs = self.odm_head(align_feat)

        bbox_inputs = outs + (img_meta, self.test_cfg, rescale)
        bbox_list = self.odm_head.get_bboxes(*bbox_inputs, rois=rois)
        bbox_results = [
            rbbox2result(det_bboxes, det_labels, self.odm_head.num_classes)
            for det_bboxes, det_labels in bbox_list
        ]
        return bbox_results
    def aug_test(self, imgs, img_metas, **kwargs):
        """Test function with test time augmentation."""
        raise NotImplementedError
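
The constructor above wires a backbone, an optional neck, the feature-alignment (FAM) head, the align convolution, and the refinement (ODM) head together from plain config dicts. Below is a minimal, hypothetical config sketch showing how those arguments fit together; the head types and numeric values are illustrative assumptions, not a verbatim copy of a shipped mmrotate config.

# Hypothetical, trimmed model config; values are illustrative assumptions.
model = dict(
    type='S2ANet',
    backbone=dict(
        type='ResNet', depth=50, num_stages=4,
        out_indices=(0, 1, 2, 3), style='pytorch'),
    neck=dict(
        type='FPN', in_channels=[256, 512, 1024, 2048],
        out_channels=256, start_level=1, num_outs=5),
    fam_head=dict(type='RotatedRetinaHead', num_classes=15, in_channels=256),
    align_cfgs=dict(
        type='AlignConv', kernel_size=3, channels=256,
        featmap_strides=[8, 16, 32, 64, 128]),
    odm_head=dict(type='ODMRefineHead', num_classes=15, in_channels=256),
    # train_cfg must carry per-stage 'fam_cfg' and 'odm_cfg' entries,
    # since __init__ reads train_cfg['fam_cfg'] and train_cfg['odm_cfg'].
    train_cfg=dict(fam_cfg=dict(), odm_cfg=dict()),
    test_cfg=dict())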
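
forward_dummy routes a single tensor through FAM, alignment, and ODM so that a complexity counter can trace the whole pipeline. A minimal sketch of driving it, assuming mmcv's get_model_complexity_info is available and that the config dict from the previous sketch is named `model`; the input shape is an arbitrary assumption.

# Sketch: counting FLOPs through forward_dummy.
from mmcv.cnn import get_model_complexity_info
from mmrotate.models import build_detector

detector = build_detector(model)           # `model` is the config dict above
detector.eval()
detector.forward = detector.forward_dummy  # let the counter call forward_dummy
flops, params = get_model_complexity_info(detector, (3, 1024, 1024))
print(f'FLOPs: {flops}  Params: {params}')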
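
simple_test returns, per image, one array per class as produced by rbbox2result. A hedged sketch of consuming that output, assuming the usual (cx, cy, w, h, angle, score) row layout for rotated boxes and that `img` / `img_meta` have already been prepared by the test pipeline; the 0.3 threshold is arbitrary.

# Sketch: reading simple_test() output.
results = detector.simple_test(img, img_meta, rescale=True)  # one entry per image
for per_image in results:
    for class_id, dets in enumerate(per_image):  # one (n, 6) array per class
        for cx, cy, w, h, angle, score in dets:
            if score > 0.3:
                print(class_id, cx, cy, w, h, angle, score)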