Shortcuts

Source code for mmrotate.core.bbox.assigners.max_convex_iou_assigner

# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.ops import convex_iou
from mmdet.core.bbox.assigners.assign_result import AssignResult
from mmdet.core.bbox.assigners.base_assigner import BaseAssigner

from ..builder import ROTATED_BBOX_ASSIGNERS


@ROTATED_BBOX_ASSIGNERS.register_module()
class MaxConvexIoUAssigner(BaseAssigner):
    """Assign a corresponding gt polygon or background to each point set.

    Each point set (proposal) is assigned one of the following values:

    - -1: not assigned / ignored (neither positive nor negative)
    - 0: negative sample (background)
    - positive integer: positive sample, 1-based index of the assigned gt

    Overlaps between point sets and gt polygons are convex IoUs computed by
    :func:`mmcv.ops.convex_iou` (see :meth:`convex_overlaps`).

    Args:
        pos_iou_thr (float): IoU threshold for positive samples.
        neg_iou_thr (float | tuple[float, float]): IoU threshold for negative
            samples. A single number marks samples with
            ``0 <= max_iou < neg_iou_thr`` as negative; a pair ``(lo, hi)``
            marks samples with ``lo <= max_iou < hi`` as negative.
        min_pos_iou (float): Minimum IoU for a sample to be assigned as
            positive in step 4 (assign the max-IoU sample(s) to each gt).
            Positive samples can therefore have an IoU below ``pos_iou_thr``.
        gt_max_assign_all (bool): Whether to assign all samples sharing the
            highest overlap with some gt to that gt.
        ignore_iof_thr (float): Threshold for ignoring samples that overlap
            ``gt_rbboxes_ignore``. Non-positive values disable ignoring.
            Only convex IoU is available for polygons, so the IoU (not IoF)
            of each sample with the ignored polygons is compared against it.
        ignore_wrt_candidates (bool): Kept for config compatibility with
            mmdet's ``MaxIoUAssigner``. Because the ignore criterion is the
            symmetric IoU, it does not affect the result.
        gpu_assign_thr (int): Upper bound on the number of gts for assigning
            on the input device. When the number of gts exceeds it, the
            assignment logic runs on CPU. Non-positive values mean never
            assigning on CPU.
    """

    def __init__(self,
                 pos_iou_thr,
                 neg_iou_thr,
                 min_pos_iou=.0,
                 gt_max_assign_all=True,
                 ignore_iof_thr=-1,
                 ignore_wrt_candidates=True,
                 gpu_assign_thr=-1):
        self.pos_iou_thr = pos_iou_thr
        self.neg_iou_thr = neg_iou_thr
        self.min_pos_iou = min_pos_iou
        self.gt_max_assign_all = gt_max_assign_all
        self.ignore_iof_thr = ignore_iof_thr
        self.ignore_wrt_candidates = ignore_wrt_candidates
        self.gpu_assign_thr = gpu_assign_thr

    def assign(
        self,
        points,
        gt_rbboxes,
        overlaps,
        gt_rbboxes_ignore=None,
        gt_labels=None,
    ):
        """Assign gt polygons to point sets by maximum convex IoU.

        The assignment is done in the following steps (order matters, later
        steps override earlier ones):

        1. assign every point set to -1 (not assigned)
        2. assign point sets whose max IoU over all gts falls in the negative
           range given by ``neg_iou_thr`` to 0 (background)
        3. assign point sets whose max IoU is at least ``pos_iou_thr`` to the
           gt they overlap most
        4. for each gt whose best IoU is at least ``min_pos_iou``, assign its
           max-IoU point set(s) to that gt

        Args:
            points (torch.Tensor): Point sets to be assigned, shape (n, 18).
            gt_rbboxes (torch.Tensor): Ground-truth polygons, shape (k, 8).
            overlaps (torch.Tensor | None): Precomputed overlaps between the
                k gts and the n point sets, shape (k, n). If None they are
                computed with :meth:`convex_overlaps`. The caller's tensor is
                not modified in place.
            gt_rbboxes_ignore (torch.Tensor, optional): Ground-truth polygons
                that are labelled as ignored, e.g. crowd boxes in COCO.
            gt_labels (torch.Tensor, optional): Labels of the gts, shape (k,).

        Returns:
            :obj:`AssignResult`: The assign result.
        """
        assign_on_cpu = (self.gpu_assign_thr > 0
                         and gt_rbboxes.shape[0] > self.gpu_assign_thr)
        device = points.device

        # Overlaps are computed on the input device before any move to CPU,
        # so `points` and the polygons always share a device.
        # NOTE(review): mmcv's `convex_iou` may only have a CUDA kernel,
        # which is another reason not to compute it on CPU — confirm.
        if overlaps is None:
            overlaps = self.convex_overlaps(gt_rbboxes, points)

        if (self.ignore_iof_thr > 0 and gt_rbboxes_ignore is not None
                and gt_rbboxes_ignore.numel() > 0 and points.numel() > 0):
            # (num_ignore, n) IoUs; take the best ignored polygon per sample.
            ignore_overlaps = self.convex_overlaps(gt_rbboxes_ignore, points)
            ignore_max_overlaps, _ = ignore_overlaps.max(dim=0)
            ignore_mask = ignore_max_overlaps > self.ignore_iof_thr
            # Clone so the caller-supplied `overlaps` is not mutated.
            overlaps = overlaps.clone()
            overlaps[:, ignore_mask.to(overlaps.device)] = -1

        # Run the assignment logic on CPU when the number of gts is large.
        if assign_on_cpu:
            overlaps = overlaps.cpu()
            if gt_labels is not None:
                gt_labels = gt_labels.cpu()

        assign_result = self.assign_wrt_overlaps(overlaps, gt_labels)

        if assign_on_cpu:
            assign_result.gt_inds = assign_result.gt_inds.to(device)
            assign_result.max_overlaps = assign_result.max_overlaps.to(device)
            if assign_result.labels is not None:
                assign_result.labels = assign_result.labels.to(device)
        return assign_result

    def assign_wrt_overlaps(self, overlaps, gt_labels=None):
        """Assign w.r.t. the overlaps of point sets with gts.

        Args:
            overlaps (torch.Tensor): Overlaps between k gts and n point sets,
                shape (k, n).
            gt_labels (torch.Tensor, optional): Labels of the k gts,
                shape (k,).

        Returns:
            :obj:`AssignResult`: The assign result.

        Raises:
            TypeError: If ``neg_iou_thr`` is neither a number nor a pair.
            ValueError: If ``neg_iou_thr`` is a sequence whose length is not 2.
        """
        num_gts, num_bboxes = overlaps.size(0), overlaps.size(1)

        # 1. assign -1 (not assigned) by default
        assigned_gt_inds = overlaps.new_full((num_bboxes, ),
                                             -1,
                                             dtype=torch.long)

        if num_gts == 0 or num_bboxes == 0:
            # No ground truth or no samples: return an empty assignment.
            max_overlaps = overlaps.new_zeros((num_bboxes, ))
            if num_gts == 0:
                # No ground truth: everything is background.
                assigned_gt_inds[:] = 0
            if gt_labels is None:
                assigned_labels = None
            else:
                assigned_labels = overlaps.new_full((num_bboxes, ),
                                                    -1,
                                                    dtype=torch.long)
            return AssignResult(
                num_gts,
                assigned_gt_inds,
                max_overlaps,
                labels=assigned_labels)

        # Per sample: best IoU over all gts and which gt achieves it.
        max_overlaps, argmax_overlaps = overlaps.max(dim=0)
        # Per gt: best IoU over all samples and which sample achieves it.
        gt_max_overlaps, gt_argmax_overlaps = overlaps.max(dim=1)

        # 2. assign negatives (background)
        if isinstance(self.neg_iou_thr, (int, float)):
            # `>= 0` keeps samples masked to -1 by the ignore step out of
            # the negatives.
            neg_mask = (max_overlaps >= 0) & (max_overlaps < self.neg_iou_thr)
        elif isinstance(self.neg_iou_thr, (tuple, list)):
            if len(self.neg_iou_thr) != 2:
                raise ValueError('neg_iou_thr must be a number or a pair of '
                                 f'numbers, got {self.neg_iou_thr!r}')
            neg_lo, neg_hi = self.neg_iou_thr
            neg_mask = (max_overlaps >= neg_lo) & (max_overlaps < neg_hi)
        else:
            raise TypeError('neg_iou_thr must be a number or a pair of '
                            f'numbers, got {type(self.neg_iou_thr).__name__}')
        assigned_gt_inds[neg_mask] = 0

        # 3. assign positives: max IoU at least the positive threshold
        pos_mask = max_overlaps >= self.pos_iou_thr
        assigned_gt_inds[pos_mask] = argmax_overlaps[pos_mask] + 1

        # 4. for each gt, assign its max-IoU sample(s) to it. Later gts
        # override earlier ones, so the loop order is significant.
        # NOTE(review): with min_pos_iou=0 and gt_max_assign_all=True, a gt
        # whose best IoU is 0 claims every sample with IoU 0 to it.
        for gt_idx in range(num_gts):
            if gt_max_overlaps[gt_idx] >= self.min_pos_iou:
                if self.gt_max_assign_all:
                    max_iou_mask = overlaps[gt_idx, :] == gt_max_overlaps[gt_idx]
                    assigned_gt_inds[max_iou_mask] = gt_idx + 1
                else:
                    assigned_gt_inds[gt_argmax_overlaps[gt_idx]] = gt_idx + 1

        if gt_labels is None:
            assigned_labels = None
        else:
            assigned_labels = assigned_gt_inds.new_full((num_bboxes, ), -1)
            assigned_mask = assigned_gt_inds > 0
            # gt indices are 1-based, labels are indexed 0-based.
            assigned_labels[assigned_mask] = gt_labels[
                assigned_gt_inds[assigned_mask] - 1]

        return AssignResult(
            num_gts, assigned_gt_inds, max_overlaps, labels=assigned_labels)

    def convex_overlaps(self, gt_rbboxes, points):
        """Compute convex IoUs between polygons and point sets.

        Args:
            gt_rbboxes (torch.Tensor): Polygons, shape (k, 8).
            points (torch.Tensor): Point sets, shape (n, 18).

        Returns:
            torch.Tensor: Overlaps between the k polygons and the n point
            sets, shape (k, n).
        """
        # `convex_iou` takes (point sets, polygons); transpose so rows index
        # polygons, matching the (k, n) layout used by the assigner.
        overlaps = convex_iou(points, gt_rbboxes)
        overlaps = overlaps.transpose(1, 0)
        return overlaps
Read the Docs v: v0.3.2
Versions
latest
stable
v0.3.2
v0.3.1
v0.3.0
v0.2.0
v0.1.1
v0.1.0
main
dev
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.