# Source code for mmrotate.core.bbox.assigners.atss_obb_assigner
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.ops import points_in_polygons
from mmdet.core.bbox.assigners.assign_result import AssignResult
from mmdet.core.bbox.assigners.base_assigner import BaseAssigner
from ..builder import ROTATED_BBOX_ASSIGNERS
from ..iou_calculators import build_iou_calculator
from ..transforms import obb2poly
@ROTATED_BBOX_ASSIGNERS.register_module()
class ATSSObbAssigner(BaseAssigner):
    """Assign a corresponding gt bbox or background to each bbox.

    Each proposal will be assigned with `0` or a positive integer
    indicating the ground truth index.

    - 0: negative sample, no assigned gt
    - positive integer: positive sample, index (1-based) of assigned gt

    Args:
        topk (int): Number of bboxes selected in each pyramid level for
            each gt. If a level holds fewer than ``topk`` bboxes, all
            bboxes of that level become candidates.
        angle_version (str): Angle representation forwarded to
            ``obb2poly`` when gt boxes are converted to polygons.
            Defaults to 'oc'.
        iou_calculator (dict): Config forwarded to
            ``build_iou_calculator``. Defaults to
            ``dict(type='RBboxOverlaps2D')``.
    """

    def __init__(self,
                 topk,
                 angle_version='oc',
                 iou_calculator=dict(type='RBboxOverlaps2D')):
        self.topk = topk
        self.angle_version = angle_version
        self.iou_calculator = build_iou_calculator(iou_calculator)

    def assign(self,
               bboxes,
               num_level_bboxes,
               gt_bboxes,
               gt_bboxes_ignore=None,
               gt_labels=None):
        """Assign gt to bboxes.

        The assignment is done in the following steps:

        1. compute iou between all bbox (bbox of all pyramid levels) and gt
        2. compute center distance between all bbox and gt
        3. on each pyramid level, for each gt, select k bbox whose center
           are closest to the gt center, so we select k*l bbox in total as
           candidates for each gt
        4. get the corresponding iou for these candidates, compute the
           mean and std, and set mean + std as the iou threshold
        5. select the candidates whose iou is greater than or equal to
           the threshold as positive
        6. limit the positive sample's center to lie inside the gt

        Args:
            bboxes (Tensor): Bounding boxes to be assigned, shape(n, 5).
            num_level_bboxes (List): num of bboxes in each level
            gt_bboxes (Tensor): Groundtruth boxes, shape (k, 5).
            gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are
                labelled as `ignored`, e.g., crowd boxes in COCO. Accepted
                for interface compatibility; this method does not use it.
            gt_labels (Tensor, optional): Label of gt_bboxes, shape (k, ).

        Returns:
            :obj:`AssignResult`: The assign result.
        """
        # Sentinel marking (bbox, gt) pairs that are not positive candidates.
        INF = 100000000
        num_gt, num_bboxes = gt_bboxes.shape[0], bboxes.shape[0]

        # compute iou between all bbox and gt
        overlaps = self.iou_calculator(bboxes, gt_bboxes)

        # assign 0 (background) by default
        assigned_gt_inds = overlaps.new_full((num_bboxes, ),
                                             0,
                                             dtype=torch.long)

        if num_gt == 0 or num_bboxes == 0:
            # No ground truth or boxes: everything stays background (the
            # indices are already 0) and every label, if requested, is -1.
            max_overlaps = overlaps.new_zeros((num_bboxes, ))
            if gt_labels is None:
                assigned_labels = None
            else:
                assigned_labels = overlaps.new_full((num_bboxes, ),
                                                    -1,
                                                    dtype=torch.long)
            return AssignResult(
                num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels)

        # compute center distance between all bbox and gt; the first two
        # columns of both tensors are used as the box centers
        gt_points = gt_bboxes[:, :2]
        bboxes_points = bboxes[:, :2]
        distances = (bboxes_points[:, None, :] -
                     gt_points[None, :, :]).pow(2).sum(-1).sqrt()

        # Selecting candidates based on the center distance
        candidate_idxs = []
        start_idx = 0
        for bboxes_per_level in num_level_bboxes:
            # on each pyramid level, for each gt,
            # select k bbox whose center are closest to the gt center
            end_idx = start_idx + bboxes_per_level
            distances_per_level = distances[start_idx:end_idx, :]
            # A level may hold fewer than ``topk`` boxes; ``Tensor.topk``
            # raises when k exceeds the dimension size, so clamp k.
            selectable_k = min(self.topk, bboxes_per_level)
            _, topk_idxs_per_level = distances_per_level.topk(
                selectable_k, dim=0, largest=False)
            # shift level-local indices to indices into the full bbox list
            candidate_idxs.append(topk_idxs_per_level + start_idx)
            start_idx = end_idx
        candidate_idxs = torch.cat(candidate_idxs, dim=0)

        # column index of every gt, reused for the gathers below
        gt_range = torch.arange(num_gt, device=overlaps.device)

        # get corresponding iou for these candidates, and compute the
        # mean and std, set mean + std as the iou threshold
        candidate_overlaps = overlaps[candidate_idxs, gt_range]
        overlaps_mean_per_gt = candidate_overlaps.mean(0)
        overlaps_std_per_gt = candidate_overlaps.std(0)
        overlaps_thr_per_gt = overlaps_mean_per_gt + overlaps_std_per_gt
        is_pos = candidate_overlaps >= overlaps_thr_per_gt[None, :]

        # limit the positive sample's center in gt
        gt_polygons = obb2poly(gt_bboxes, self.angle_version)
        inside_flag = points_in_polygons(bboxes_points, gt_polygons)
        is_in_gts = inside_flag[candidate_idxs, gt_range].to(is_pos.dtype)
        is_pos = is_pos & is_in_gts

        # Turn (bbox index, gt index) pairs into flat indices of the
        # transposed, flattened overlap matrix: gt_idx * num_bboxes + bbox.
        flat_candidate_idxs = (candidate_idxs +
                               gt_range[None, :] * num_bboxes).reshape(-1)

        # if an anchor box is assigned to multiple gts,
        # the one with the highest IoU will be selected.
        overlaps_inf = torch.full_like(overlaps,
                                       -INF).t().contiguous().view(-1)
        index = flat_candidate_idxs[is_pos.reshape(-1)]
        overlaps_inf[index] = overlaps.t().contiguous().view(-1)[index]
        overlaps_inf = overlaps_inf.view(num_gt, -1).t()

        max_overlaps, argmax_overlaps = overlaps_inf.max(dim=1)
        # rows still equal to -INF had no positive candidate for any gt
        has_gt = max_overlaps != -INF
        assigned_gt_inds[has_gt] = argmax_overlaps[has_gt] + 1

        if gt_labels is not None:
            assigned_labels = assigned_gt_inds.new_full((num_bboxes, ), -1)
            # squeeze only the trailing dim so a single positive stays 1-D
            pos_inds = torch.nonzero(
                assigned_gt_inds > 0, as_tuple=False).squeeze(1)
            if pos_inds.numel() > 0:
                assigned_labels[pos_inds] = gt_labels[
                    assigned_gt_inds[pos_inds] - 1]
        else:
            assigned_labels = None
        return AssignResult(
            num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels)