Source code for mmrotate.core.anchor.anchor_generator
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.utils import to_2tuple
from mmdet.core.anchor import AnchorGenerator
from .builder import ROTATED_ANCHOR_GENERATORS
[docs]@ROTATED_ANCHOR_GENERATORS.register_module()
class RotatedAnchorGenerator(AnchorGenerator):
"""Fake rotate anchor generator for 2D anchor-based detectors.
Horizontal bounding box represented by (x,y,w,h,theta).
"""
[docs] def single_level_grid_priors(self,
featmap_size,
level_idx,
dtype=torch.float32,
device='cuda'):
"""Generate grid anchors of a single level.
Note:
This function is usually called by method ``self.grid_priors``.
Args:
featmap_size (tuple[int]): Size of the feature maps.
level_idx (int): The index of corresponding feature map level.
dtype (obj:`torch.dtype`): Date type of points.Defaults to
``torch.float32``.
device (str, optional): The device the tensor will be put on.
Defaults to 'cuda'.
Returns:
torch.Tensor: Anchors in the overall feature maps.
"""
anchors = super(RotatedAnchorGenerator, self).single_level_grid_priors(
featmap_size, level_idx, dtype=dtype, device=device)
# The correct usage is:
# from ..bbox.transforms import hbb2obb
# anchors = hbb2obb(anchors, self.angle_version)
# instead of rudely setting the angle to all 0.
# However, the experiment shows that the performance has decreased.
num_anchors = anchors.size(0)
xy = (anchors[:, 2:] + anchors[:, :2]) / 2
wh = anchors[:, 2:] - anchors[:, :2]
theta = xy.new_zeros((num_anchors, 1))
anchors = torch.cat([xy, wh, theta], axis=1)
return anchors
[docs]@ROTATED_ANCHOR_GENERATORS.register_module()
class PseudoAnchorGenerator(AnchorGenerator):
"""Non-Standard pseudo anchor generator that is used to generate valid
flags only!"""
def __init__(self, strides):
self.strides = [to_2tuple(stride) for stride in strides]
@property
def num_base_anchors(self):
"""list[int]: total number of base anchors in a feature grid"""
return [1 for _ in self.strides]
[docs] def single_level_grid_anchors(self, featmap_sizes, device='cuda'):
"""Calling its grid_anchors() method will raise NotImplementedError!"""
raise NotImplementedError
def __repr__(self):
indent_str = ' '
repr_str = self.__class__.__name__ + '(\n'
repr_str += f'{indent_str}strides={self.strides})'
return repr_str