Shortcuts

Source code for mmrotate.core.bbox.samplers.rotate_random_sampler

# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmdet.core.bbox.samplers.base_sampler import BaseSampler
from mmdet.core.bbox.samplers.sampling_result import SamplingResult

from ..builder import ROTATED_BBOX_SAMPLERS


@ROTATED_BBOX_SAMPLERS.register_module()
class RRandomSampler(BaseSampler):
    """Random sampler.

    Args:
        num (int): Number of samples
        pos_fraction (float): Fraction of positive samples
        neg_pos_ub (int, optional): Upper bound on the number of negative
            samples, expressed as a multiple of the number of sampled
            positives. A negative value disables the bound.
            Defaults to -1.
        add_gt_as_proposals (bool, optional): Whether to add ground truth
            boxes as proposals. Defaults to True.
    """

    def __init__(self,
                 num,
                 pos_fraction,
                 neg_pos_ub=-1,
                 add_gt_as_proposals=True,
                 **kwargs):
        # Imported lazily, as in the original implementation.
        from mmdet.core.bbox import demodata
        super(RRandomSampler, self).__init__(num, pos_fraction, neg_pos_ub,
                                             add_gt_as_proposals)
        # Optional ``rng`` keyword is normalised by ``ensure_rng``.
        self.rng = demodata.ensure_rng(kwargs.get('rng', None))

    def random_choice(self, gallery, num):
        """Random select some elements from the gallery.

        If `gallery` is a Tensor, the returned indices will be a Tensor;
        If `gallery` is a ndarray or list, the returned indices will be a
        ndarray.

        Args:
            gallery (Tensor | ndarray | list): indices pool.
            num (int): expected sample num.

        Returns:
            Tensor or ndarray: sampled indices.
        """
        assert len(gallery) >= num

        is_tensor = isinstance(gallery, torch.Tensor)
        if not is_tensor:
            # Only ask for the current CUDA device when CUDA is usable;
            # otherwise fall back to the CPU so that list/ndarray
            # galleries also work on CPU-only machines.
            if torch.cuda.is_available():
                device = torch.cuda.current_device()
            else:
                device = 'cpu'
            gallery = torch.tensor(gallery, dtype=torch.long, device=device)
        # Shuffle positions on the gallery's own device and keep the
        # first ``num`` of them.
        perm = torch.randperm(gallery.numel(), device=gallery.device)[:num]
        rand_inds = gallery[perm]
        if not is_tensor:
            rand_inds = rand_inds.cpu().numpy()
        return rand_inds

    def _sample_pos(self, assign_result, num_expected, **kwargs):
        """Randomly sample some positive samples."""
        # Positive entries are those with ``gt_inds > 0``.
        pos_inds = torch.nonzero(assign_result.gt_inds > 0, as_tuple=False)
        if pos_inds.numel() != 0:
            pos_inds = pos_inds.squeeze(1)
        if pos_inds.numel() <= num_expected:
            return pos_inds
        else:
            return self.random_choice(pos_inds, num_expected)

    def _sample_neg(self, assign_result, num_expected, **kwargs):
        """Randomly sample some negative samples."""
        # Negative entries are those with ``gt_inds == 0``.
        neg_inds = torch.nonzero(assign_result.gt_inds == 0, as_tuple=False)
        if neg_inds.numel() != 0:
            neg_inds = neg_inds.squeeze(1)
        if len(neg_inds) <= num_expected:
            return neg_inds
        else:
            return self.random_choice(neg_inds, num_expected)

    def sample(self,
               assign_result,
               bboxes,
               gt_bboxes,
               gt_labels=None,
               **kwargs):
        """Sample positive and negative bboxes.

        This is a simple implementation of bbox sampling given candidates,
        assigning results and ground truth bboxes.

        Args:
            assign_result (:obj:`AssignResult`): Bbox assigning results.
            bboxes (torch.Tensor): Boxes to be sampled from.
            gt_bboxes (torch.Tensor): Ground truth bboxes.
            gt_labels (Tensor, optional): Class labels of ground truth
                bboxes.

        Returns:
            :obj:`SamplingResult`: Sampling result.

        Raises:
            ValueError: If ``add_gt_as_proposals`` is True, ``gt_bboxes``
                is non-empty and ``gt_labels`` is None.

        Example:
            >>> from mmdet.core.bbox import AssignResult
            >>> from mmdet.core.bbox.demodata import ensure_rng, random_boxes
            >>> from mmrotate.core.bbox.samplers.rotate_random_sampler import RRandomSampler  # noqa: E501
            >>> rng = ensure_rng(None)
            >>> assign_result = AssignResult.random(rng=rng)
            >>> bboxes = random_boxes(assign_result.num_preds, rng=rng)
            >>> gt_bboxes = random_boxes(assign_result.num_gts, rng=rng)
            >>> gt_labels = None
            >>> sampler = RRandomSampler(num=32, pos_fraction=0.5,
            >>>                          neg_pos_ub=-1,
            >>>                          add_gt_as_proposals=False)
            >>> result = sampler.sample(assign_result, bboxes, gt_bboxes,
            >>>                         gt_labels)
        """
        if len(bboxes.shape) < 2:
            bboxes = bboxes[None, :]

        # Keep at most the first five columns of each box.
        bboxes = bboxes[:, :5]

        gt_flags = bboxes.new_zeros((bboxes.shape[0], ), dtype=torch.uint8)
        if self.add_gt_as_proposals and len(gt_bboxes) > 0:
            if gt_labels is None:
                raise ValueError(
                    'gt_labels must be given when add_gt_as_proposals is True')
            # Ground-truth boxes are prepended, so their flags come first.
            bboxes = torch.cat([gt_bboxes, bboxes], dim=0)
            assign_result.add_gt_(gt_labels)
            gt_ones = bboxes.new_ones(gt_bboxes.shape[0], dtype=torch.uint8)
            gt_flags = torch.cat([gt_ones, gt_flags])

        num_expected_pos = int(self.num * self.pos_fraction)
        pos_inds = self.pos_sampler._sample_pos(
            assign_result, num_expected_pos, bboxes=bboxes, **kwargs)
        # We found that sampled indices have duplicated items occasionally.
        # (may be a bug of PyTorch)
        pos_inds = pos_inds.unique()
        num_sampled_pos = pos_inds.numel()
        # Negatives fill whatever part of the budget positives left over.
        num_expected_neg = self.num - num_sampled_pos
        if self.neg_pos_ub >= 0:
            # Use at least one "positive" so the bound is never zero-scaled.
            _pos = max(1, num_sampled_pos)
            neg_upper_bound = int(self.neg_pos_ub * _pos)
            if num_expected_neg > neg_upper_bound:
                num_expected_neg = neg_upper_bound
        neg_inds = self.neg_sampler._sample_neg(
            assign_result, num_expected_neg, bboxes=bboxes, **kwargs)
        neg_inds = neg_inds.unique()

        sampling_result = SamplingResult(pos_inds, neg_inds, bboxes, gt_bboxes,
                                         assign_result, gt_flags)
        return sampling_result
Read the Docs v: v0.2.0
Versions
latest
stable
v0.2.0
v0.1.1
v0.1.0
main
dev
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.