Shortcuts

Source code for mmrotate.core.anchor.anchor_generator

# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.utils import to_2tuple
from mmdet.core.anchor import AnchorGenerator

from .builder import ROTATED_ANCHOR_GENERATORS


@ROTATED_ANCHOR_GENERATORS.register_module()
class RotatedAnchorGenerator(AnchorGenerator):
    """Fake rotated anchor generator for 2D anchor-based detectors.

    Anchors are produced by the parent ``AnchorGenerator`` as horizontal
    boxes and then re-encoded in the rotated-box layout
    ``(x, y, w, h, theta)`` with ``theta`` fixed to 0 for every anchor.
    """

    def single_level_grid_priors(self,
                                 featmap_size,
                                 level_idx,
                                 dtype=torch.float32,
                                 device='cuda'):
        """Generate grid anchors of a single level.

        Note:
            This function is usually called by method ``self.grid_priors``.

        Args:
            featmap_size (tuple[int]): Size of the feature maps.
            level_idx (int): The index of corresponding feature map level.
            dtype (obj:`torch.dtype`): Data type of points. Defaults to
                ``torch.float32``.
            device (str, optional): The device the tensor will be put on.
                Defaults to 'cuda'.

        Returns:
            torch.Tensor: Anchors in the overall feature maps, one row per
            anchor laid out as ``(x, y, w, h, theta)`` with ``theta == 0``.
        """
        anchors = super().single_level_grid_priors(
            featmap_size, level_idx, dtype=dtype, device=device)

        # NOTE: The geometrically correct conversion would be
        #   from ..bbox.transforms import hbb2obb
        #   anchors = hbb2obb(anchors, self.angle_version)
        # instead of forcing every angle to 0. However, experiments showed
        # that it decreases performance, so the angle is kept at 0.

        # Parent rows are treated as (x1, y1, x2, y2) corners and turned
        # into center (x, y) and size (w, h).
        xy = (anchors[:, :2] + anchors[:, 2:]) / 2
        wh = anchors[:, 2:] - anchors[:, :2]
        theta = xy.new_zeros((anchors.size(0), 1))
        return torch.cat([xy, wh, theta], dim=1)
@ROTATED_ANCHOR_GENERATORS.register_module()
class PseudoAnchorGenerator(AnchorGenerator):
    """Non-Standard pseudo anchor generator that is used to generate valid
    flags only!

    Only the strides are stored and every feature location is reported as
    having exactly one base anchor. Anchor generation itself is not
    supported.

    Args:
        strides (list): Strides for each feature level; each entry is
            normalized with ``to_2tuple``.
    """

    def __init__(self, strides):
        # ``AnchorGenerator.__init__`` is not called: this class keeps only
        # the strides, which is all ``num_base_anchors`` and ``__repr__``
        # below read.
        self.strides = [to_2tuple(stride) for stride in strides]

    @property
    def num_base_anchors(self):
        """list[int]: total number of base anchors in a feature grid"""
        return [1] * len(self.strides)

    def single_level_grid_anchors(self, featmap_sizes, device='cuda'):
        """Not supported by this generator.

        Calling its ``grid_anchors()`` method will raise
        ``NotImplementedError``.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            f'{self.__class__.__name__} is only used to generate valid '
            'flags; anchor generation is not supported.')

    def __repr__(self):
        indent_str = ' '
        repr_str = self.__class__.__name__ + '(\n'
        repr_str += f'{indent_str}strides={self.strides})'
        return repr_str
Read the Docs v: v0.3.2
Versions
latest
stable
v0.3.2
v0.3.1
v0.3.0
v0.2.0
v0.1.1
v0.1.0
main
dev
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.