Shortcuts

Source code for mmrotate.core.bbox.samplers.rotate_random_sampler

# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmdet.core.bbox.samplers.base_sampler import BaseSampler
from mmdet.core.bbox.samplers.sampling_result import SamplingResult

from ..builder import ROTATED_BBOX_SAMPLERS


@ROTATED_BBOX_SAMPLERS.register_module()
class RRandomSampler(BaseSampler):
    """Random sampler.

    Args:
        num (int): Number of samples.
        pos_fraction (float): Fraction of positive samples.
        neg_pos_ub (int, optional): Upper bound of the ratio of negative to
            positive samples. A negative value disables the bound.
            Defaults to -1.
        add_gt_as_proposals (bool, optional): Whether to add ground truth
            boxes as proposals. Defaults to True.
        **kwargs: Only ``rng`` is read; it is passed through
            ``demodata.ensure_rng`` and stored as ``self.rng``. Other keys are
            ignored.
    """

    def __init__(self,
                 num,
                 pos_fraction,
                 neg_pos_ub=-1,
                 add_gt_as_proposals=True,
                 **kwargs):
        # Local import kept from the original; presumably avoids an import
        # cycle with mmdet.core.bbox at module load time — TODO confirm.
        from mmdet.core.bbox import demodata
        super().__init__(num, pos_fraction, neg_pos_ub, add_gt_as_proposals)
        # NOTE(review): self.rng is not consumed by the sampling methods of
        # this class; random_choice draws from torch's default generator.
        self.rng = demodata.ensure_rng(kwargs.get('rng', None))

    def random_choice(self, gallery, num):
        """Randomly select ``num`` elements from ``gallery``.

        If ``gallery`` is a Tensor, the returned indices are a Tensor on the
        same device. If ``gallery`` is an ndarray or list, the returned
        indices are an ndarray.

        Args:
            gallery (Tensor | ndarray | list): Indices pool.
            num (int): Expected number of samples.

        Returns:
            Tensor | ndarray: Sampled indices.
        """
        assert len(gallery) >= num

        is_tensor = isinstance(gallery, torch.Tensor)
        if not is_tensor:
            # torch.cuda.current_device() fails when CUDA is unavailable, so
            # only target the GPU when one is actually usable.
            if torch.cuda.is_available():
                device = torch.cuda.current_device()
            else:
                device = 'cpu'
            gallery = torch.tensor(gallery, dtype=torch.long, device=device)
        perm = torch.randperm(gallery.numel(), device=gallery.device)[:num]
        rand_inds = gallery[perm]
        if not is_tensor:
            # Non-tensor input gets a NumPy result, per the contract above.
            rand_inds = rand_inds.cpu().numpy()
        return rand_inds

    def _sample_pos(self, assign_result, num_expected, **kwargs):
        """Randomly sample some positive samples.

        Positives are candidates whose ``assign_result.gt_inds`` is > 0. All
        of them are returned when there are at most ``num_expected``;
        otherwise ``num_expected`` are drawn via :meth:`random_choice`.
        """
        pos_inds = torch.nonzero(assign_result.gt_inds > 0, as_tuple=False)
        if pos_inds.numel() != 0:
            # nonzero() yields shape (k, 1); flatten to (k,).
            pos_inds = pos_inds.squeeze(1)
        if pos_inds.numel() <= num_expected:
            return pos_inds
        else:
            return self.random_choice(pos_inds, num_expected)

    def _sample_neg(self, assign_result, num_expected, **kwargs):
        """Randomly sample some negative samples.

        Negatives are candidates whose ``assign_result.gt_inds`` equals 0.
        All of them are returned when there are at most ``num_expected``;
        otherwise ``num_expected`` are drawn via :meth:`random_choice`.
        """
        neg_inds = torch.nonzero(assign_result.gt_inds == 0, as_tuple=False)
        if neg_inds.numel() != 0:
            # nonzero() yields shape (k, 1); flatten to (k,).
            neg_inds = neg_inds.squeeze(1)
        if neg_inds.numel() <= num_expected:
            return neg_inds
        else:
            return self.random_choice(neg_inds, num_expected)

    def sample(self,
               assign_result,
               bboxes,
               gt_bboxes,
               gt_labels=None,
               **kwargs):
        """Sample positive and negative bboxes.

        This is a simple implementation of bbox sampling given candidates,
        assigning results and ground truth bboxes.

        Args:
            assign_result (:obj:`AssignResult`): Bbox assigning results.
                Modified in place (``add_gt_``) when ground truth boxes are
                added as proposals.
            bboxes (torch.Tensor): Boxes to be sampled from. A 1-D tensor is
                treated as a single box; only the first 5 columns are kept.
            gt_bboxes (torch.Tensor): Ground truth bboxes.
            gt_labels (Tensor, optional): Class labels of ground truth bboxes.
                Required when ``add_gt_as_proposals`` is True and
                ``gt_bboxes`` is non-empty.

        Returns:
            :obj:`SamplingResult`: Sampling result.

        Raises:
            ValueError: If ground truth boxes are to be added as proposals
                but ``gt_labels`` is None.

        Example:
            >>> from mmdet.core.bbox import AssignResult
            >>> from mmdet.core.bbox.demodata import ensure_rng, random_boxes
            >>> from mmrotate.core.bbox.samplers.rotate_random_sampler import (
            ...     RRandomSampler)
            >>> rng = ensure_rng(None)
            >>> assign_result = AssignResult.random(rng=rng)
            >>> bboxes = random_boxes(assign_result.num_preds, rng=rng)
            >>> gt_bboxes = random_boxes(assign_result.num_gts, rng=rng)
            >>> gt_labels = None
            >>> sampler = RRandomSampler(num=32, pos_fraction=0.5,
            ...                          neg_pos_ub=-1,
            ...                          add_gt_as_proposals=False)
            >>> sampling_result = sampler.sample(assign_result, bboxes,
            ...                                  gt_bboxes, gt_labels)
        """
        if len(bboxes.shape) < 2:
            bboxes = bboxes[None, :]

        # Any columns beyond the 5th are dropped before sampling.
        bboxes = bboxes[:, :5]

        # gt_flags[i] == 1 marks candidates that came from gt_bboxes.
        gt_flags = bboxes.new_zeros((bboxes.shape[0], ), dtype=torch.uint8)
        if self.add_gt_as_proposals and len(gt_bboxes) > 0:
            if gt_labels is None:
                raise ValueError(
                    'gt_labels must be given when add_gt_as_proposals is True')
            # Ground truth boxes are prepended to the candidate list.
            bboxes = torch.cat([gt_bboxes, bboxes], dim=0)
            assign_result.add_gt_(gt_labels)
            gt_ones = bboxes.new_ones(gt_bboxes.shape[0], dtype=torch.uint8)
            gt_flags = torch.cat([gt_ones, gt_flags])

        # self.pos_sampler / self.neg_sampler come from BaseSampler (not
        # shown here) — presumably this sampler itself; verify upstream.
        num_expected_pos = int(self.num * self.pos_fraction)
        pos_inds = self.pos_sampler._sample_pos(
            assign_result, num_expected_pos, bboxes=bboxes, **kwargs)
        # We found that sampled indices have duplicated items occasionally.
        # (may be a bug of PyTorch)
        pos_inds = pos_inds.unique()
        num_sampled_pos = pos_inds.numel()

        # Negatives fill the remaining budget, optionally capped at
        # neg_pos_ub times the number of positives (at least 1 positive is
        # assumed so the cap is never forced to zero).
        num_expected_neg = self.num - num_sampled_pos
        if self.neg_pos_ub >= 0:
            _pos = max(1, num_sampled_pos)
            neg_upper_bound = int(self.neg_pos_ub * _pos)
            if num_expected_neg > neg_upper_bound:
                num_expected_neg = neg_upper_bound
        neg_inds = self.neg_sampler._sample_neg(
            assign_result, num_expected_neg, bboxes=bboxes, **kwargs)
        neg_inds = neg_inds.unique()

        sampling_result = SamplingResult(pos_inds, neg_inds, bboxes, gt_bboxes,
                                         assign_result, gt_flags)
        return sampling_result
Read the Docs v: stable
Versions
latest
stable
1.x
v1.0.0rc0
v0.3.4
v0.3.3
v0.3.2
v0.3.1
v0.3.0
v0.2.0
v0.1.1
v0.1.0
main
dev
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.