Source code for mmrotate.models.detectors.r3det
# Copyright (c) SJTU. All rights reserved.
import warnings
from mmcv.runner import ModuleList
from mmrotate.core import rbbox2result
from ..builder import ROTATED_DETECTORS, build_backbone, build_head, build_neck
from .base import RotatedBaseDetector
from .utils import FeatureRefineModule
@ROTATED_DETECTORS.register_module()
class R3Det(RotatedBaseDetector):
    """Rotated Refinement RetinaNet.

    A base head (``bbox_head``) is run on the backbone/neck features and its
    outputs are filtered into rois. Then ``num_refine_stages`` refinement
    stages follow: stage ``i`` re-computes the features with
    ``feat_refine_module[i]`` from the current rois and feeds them to
    ``refine_head[i]``.

    Args:
        num_refine_stages (int): Number of refinement stages executed in
            the forward passes.
        backbone (dict): Config passed to ``build_backbone``.
        neck (dict, optional): Config passed to ``build_neck``. No neck is
            built when it is None.
        bbox_head (dict): Config of the base head, passed to ``build_head``.
            Despite the ``None`` default it is required, because it is
            updated unconditionally. The config is modified in place.
        frm_cfgs (list[dict]): Keyword arguments of each
            ``FeatureRefineModule``, one entry per refinement stage.
        refine_heads (list[dict]): Configs of the refinement heads, passed
            to ``build_head`` and paired with ``frm_cfgs`` by position. The
            configs are modified in place.
        train_cfg (dict, optional): Training config. ``train_cfg['s0']`` is
            given to the base head and ``train_cfg['sr'][i]`` to the i-th
            refinement head; ``train_cfg.stage_loss_weights`` is read in
            :meth:`forward_train`.
        test_cfg (dict, optional): Testing config handed to every head.
        pretrained (str, optional): Deprecated, use ``init_cfg`` instead.
            When set, it is stored on the backbone config.
        init_cfg (dict, optional): Forwarded to the base class constructor.
    """

    def __init__(self,
                 num_refine_stages,
                 backbone,
                 neck=None,
                 bbox_head=None,
                 frm_cfgs=None,
                 refine_heads=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None,
                 init_cfg=None):
        super(R3Det, self).__init__(init_cfg)
        if pretrained:
            warnings.warn('DeprecationWarning: pretrained is deprecated, '
                          'please use "init_cfg" instead')
            backbone.pretrained = pretrained
        self.backbone = build_backbone(backbone)
        self.num_refine_stages = num_refine_stages
        if neck is not None:
            self.neck = build_neck(neck)

        # The base head uses the 's0' part of the training config.
        if train_cfg is not None:
            bbox_head.update(train_cfg=train_cfg['s0'])
        bbox_head.update(test_cfg=test_cfg)
        self.bbox_head = build_head(bbox_head)

        # One feature refine module and one head per refinement stage; the
        # i-th head uses the i-th entry of train_cfg['sr'].
        self.feat_refine_module = ModuleList()
        self.refine_head = ModuleList()
        for i, (frm_cfg,
                refine_head) in enumerate(zip(frm_cfgs, refine_heads)):
            self.feat_refine_module.append(FeatureRefineModule(**frm_cfg))
            if train_cfg is not None:
                refine_head.update(train_cfg=train_cfg['sr'][i])
            refine_head.update(test_cfg=test_cfg)
            self.refine_head.append(build_head(refine_head))

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

    def extract_feat(self, img):
        """Extract features with the backbone, then the neck if present."""
        x = self.backbone(img)
        if self.with_neck:
            x = self.neck(x)
        return x

    def forward_dummy(self, img):
        """Used for computing network flops.

        See `mmdetection/tools/get_flops.py`

        Returns the raw outputs of the last executed head (the base head
        when ``num_refine_stages`` is 0).
        """
        x = self.extract_feat(img)
        outs = self.bbox_head(x)
        rois = self.bbox_head.filter_bboxes(*outs)
        # rois: list(indexed by images) of list(indexed by levels)
        last_stage = self.num_refine_stages - 1
        for i in range(self.num_refine_stages):
            x_refine = self.feat_refine_module[i](x, rois)
            outs = self.refine_head[i](x_refine)
            # The rois are only refined when another stage consumes them.
            if i < last_stage:
                rois = self.refine_head[i].refine_bboxes(*outs, rois=rois)
        return outs

    def forward_train(self,
                      img,
                      img_metas,
                      gt_bboxes,
                      gt_labels,
                      gt_bboxes_ignore=None):
        """Forward function during training.

        Args:
            img: Input images, forwarded to :meth:`extract_feat`.
            img_metas: Image meta information, passed to the head losses.
            gt_bboxes: Ground-truth boxes, passed to the head losses.
            gt_labels: Ground-truth labels, passed to the head losses.
            gt_bboxes_ignore (optional): Boxes to ignore, passed to the
                head losses. Defaults to None.

        Returns:
            dict: Losses of all stages. Entries of the base head are
            prefixed with ``s0.`` and entries of the i-th refinement head
            with ``sr{i}.``.
        """
        losses = dict()
        x = self.extract_feat(img)

        outs = self.bbox_head(x)
        loss_inputs = outs + (gt_bboxes, gt_labels, img_metas)
        loss_base = self.bbox_head.loss(
            *loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
        for name, value in loss_base.items():
            losses[f's0.{name}'] = value

        rois = self.bbox_head.filter_bboxes(*outs)
        # rois: list(indexed by images) of list(indexed by levels)
        last_stage = self.num_refine_stages - 1
        for i in range(self.num_refine_stages):
            lw = self.train_cfg.stage_loss_weights[i]

            x_refine = self.feat_refine_module[i](x, rois)
            outs = self.refine_head[i](x_refine)
            loss_inputs = outs + (gt_bboxes, gt_labels, img_metas)
            loss_refine = self.refine_head[i].loss(
                *loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore, rois=rois)
            # Only entries whose name contains 'loss' are scaled by the
            # stage weight (element by element); others are kept as is.
            for name, value in loss_refine.items():
                losses[f'sr{i}.{name}'] = ([v * lw for v in value]
                                           if 'loss' in name else value)

            # The rois are only refined when another stage consumes them.
            if i < last_stage:
                rois = self.refine_head[i].refine_bboxes(*outs, rois=rois)

        return losses

    def simple_test(self, img, img_meta, rescale=False):
        """Test function without test time augmentation.

        Args:
            img (torch.Tensor): Input images, forwarded to
                :meth:`extract_feat`.
            img_meta (list[dict]): List of image information.
            rescale (bool, optional): Whether to rescale the results.
                Defaults to False.

        Returns:
            list[list[np.ndarray]]: BBox results of each image and classes. \
                The outer list corresponds to each image. The inner list \
                corresponds to each class.
        """
        x = self.extract_feat(img)
        outs = self.bbox_head(x)
        rois = self.bbox_head.filter_bboxes(*outs)
        # rois: list(indexed by images) of list(indexed by levels)
        last_stage = self.num_refine_stages - 1
        for i in range(self.num_refine_stages):
            x_refine = self.feat_refine_module[i](x, rois)
            outs = self.refine_head[i](x_refine)
            # The rois are only refined when another stage consumes them.
            if i < last_stage:
                rois = self.refine_head[i].refine_bboxes(*outs, rois=rois)

        # The final boxes are decoded by the last refinement head.
        bbox_inputs = outs + (img_meta, self.test_cfg, rescale)
        bbox_list = self.refine_head[-1].get_bboxes(*bbox_inputs, rois=rois)
        bbox_results = [
            rbbox2result(det_bboxes, det_labels,
                         self.refine_head[-1].num_classes)
            for det_bboxes, det_labels in bbox_list
        ]
        return bbox_results

    def aug_test(self, imgs, img_metas, **kwargs):
        """Test function with test time augmentation.

        Raises:
            NotImplementedError: Always; test time augmentation is not
                implemented for this detector.
        """
        # Fail loudly instead of silently returning None to the caller.
        raise NotImplementedError(
            'R3Det does not support test time augmentation.')