Shortcuts

Source code for mmrotate.datasets.pipelines.transforms

# Copyright (c) OpenMMLab. All rights reserved.
import cv2
import mmcv
import numpy as np
from mmdet.datasets.pipelines.transforms import RandomCrop, RandomFlip, Resize

from mmrotate.core import norm_angle, obb2poly_np, poly2obb_np
from ..builder import ROTATED_PIPELINES


@ROTATED_PIPELINES.register_module()
class RResize(Resize):
    """Resize images together with their rotated bounding boxes.

    Specialises the ``Resize`` pipeline so that boxes stored as groups of
    five values are rescaled alongside the image. The aspect ratio is always
    kept (``keep_ratio=True`` is passed to the parent constructor).

    Args:
        img_scale (tuple or list[tuple]): Images scales for resizing.
        multiscale_mode (str): Either "range" or "value".
        ratio_range (tuple[float]): (min_ratio, max_ratio).
    """

    def __init__(self,
                 img_scale=None,
                 multiscale_mode='range',
                 ratio_range=None):
        parent_kwargs = dict(
            img_scale=img_scale,
            multiscale_mode=multiscale_mode,
            ratio_range=ratio_range,
            keep_ratio=True)
        super().__init__(**parent_kwargs)

    def _resize_bboxes(self, results):
        """Rescale every bbox field using ``results['scale_factor']``.

        Each array listed in ``results['bbox_fields']`` is viewed as rows of
        five values. Column 0 is multiplied by the width scale, column 1 by
        the height scale, and columns 2 and 3 by the geometric mean of the
        two scales. Column 4 is left untouched.

        Args:
            results (dict): Pipeline result dict; updated in place.
        """
        for field in results.get('bbox_fields', []):
            boxes = results[field]
            original_shape = boxes.shape
            rows = boxes.reshape((-1, 5))
            w_scale, h_scale, _, _ = results['scale_factor']
            # A single factor is applied to both side lengths.
            side_scale = np.sqrt(w_scale * h_scale)
            rows[:, 0] *= w_scale
            rows[:, 1] *= h_scale
            rows[:, 2] *= side_scale
            rows[:, 3] *= side_scale
            results[field] = rows.reshape(original_shape)
@ROTATED_PIPELINES.register_module()
class RRandomFlip(RandomFlip):
    """Randomly flip images and their rotated bounding boxes.

    Args:
        flip_ratio (float | list[float], optional): The flipping probability.
            Default: None.
        direction(str | list[str], optional): The flipping direction. Options
            are 'horizontal', 'vertical', 'diagonal'.
        version (str, optional): Angle representations. Defaults to 'oc'.
    """

    def __init__(self, flip_ratio=None, direction='horizontal', version='oc'):
        # The version must be stored before delegating to the parent.
        self.version = version
        super().__init__(flip_ratio, direction)

    def bbox_flip(self, bboxes, img_shape, direction):
        """Flip rotated bboxes horizontally, vertically or diagonally.

        Args:
            bboxes(ndarray): shape (..., 5*k)
            img_shape(tuple): (height, width)
            direction (str): 'horizontal', 'vertical' or 'diagonal'.

        Returns:
            numpy.ndarray: Flipped bounding boxes, same shape as the input.

        Raises:
            ValueError: If ``direction`` is not one of the three options.
        """
        assert bboxes.shape[-1] % 5 == 0
        input_shape = bboxes.shape
        boxes = bboxes.reshape((-1, 5))
        mirrored = boxes.copy()

        mirror_x = direction in ('horizontal', 'diagonal')
        mirror_y = direction in ('vertical', 'diagonal')
        if not (mirror_x or mirror_y):
            raise ValueError(f'Invalid flipping direction "{direction}"')

        if mirror_x:
            mirrored[:, 0] = img_shape[1] - boxes[:, 0] - 1
        if mirror_y:
            mirrored[:, 1] = img_shape[0] - boxes[:, 1] - 1
        if direction == 'diagonal':
            # Only the centres move for a diagonal flip; columns 2-4 are kept.
            return mirrored.reshape(input_shape)

        if self.version != 'oc':
            mirrored[:, 4] = norm_angle(np.pi - boxes[:, 4], self.version)
        else:
            # Boxes whose angle equals pi/2 keep their angle and side lengths;
            # all others get the complementary angle and swapped sides.
            needs_swap = boxes[:, 4] != np.pi / 2
            mirrored[needs_swap, 4] = np.pi / 2 - boxes[needs_swap, 4]
            mirrored[needs_swap, 2] = boxes[needs_swap, 3]
            mirrored[needs_swap, 3] = boxes[needs_swap, 2]
        return mirrored.reshape(input_shape)
@ROTATED_PIPELINES.register_module()
class PolyRandomRotate(object):
    """Rotate img & bbox.

    Reference: https://github.com/hukaixuan19970627/OrientedRepPoints_DOTA

    Args:
        rotate_ratio (float, optional): The rotating probability.
            Default: 0.5.
        mode (str, optional) : Indicates whether the angle is chosen in a
            random range (mode='range') or in a preset list of angles
            (mode='value'). Defaults to 'range'.
        angles_range(int|list[int], optional): The range of angles, in
            degrees. If mode='range', angles_range is an int and the angle is
            chosen in (-angles_range, +angles_range). If mode='value',
            angles_range is a non-empty list of int and the angle is chosen
            in angles_range. Defaults to 180 as default mode is 'range'.
        auto_bound(bool, optional): whether to find the new width and height
            bounds.
        rect_classes (None|list, optional): Specifies classes that needs to
            be rotated by a multiple of 90 degrees.
        version (str, optional): Angle representations. Defaults to 'le90'.
    """

    def __init__(self,
                 rotate_ratio=0.5,
                 mode='range',
                 angles_range=180,
                 auto_bound=False,
                 rect_classes=None,
                 version='le90'):
        self.rotate_ratio = rotate_ratio
        self.auto_bound = auto_bound
        assert mode in ['range', 'value'], \
            f"mode is supposed to be 'range' or 'value', but got {mode}."
        if mode == 'range':
            assert isinstance(angles_range, int), \
                "mode 'range' expects angle_range to be an int."
        else:
            assert mmcv.is_seq_of(angles_range, int) and len(angles_range), \
                "mode 'value' expects angle_range as a non-empty list of int."
        self.mode = mode
        self.angles_range = angles_range
        # Candidate angles (degrees) used when a ``rect_classes`` label is
        # present in the sample.
        self.discrete_range = [90, 180, -90, -180]
        self.rect_classes = rect_classes
        self.version = version

    @property
    def is_rotate(self):
        """Randomly decide whether to rotate."""
        return np.random.rand() < self.rotate_ratio

    def apply_image(self, img, bound_h, bound_w, interp=cv2.INTER_LINEAR):
        """Warp an image with the stored image rotation matrix.

        Args:
            img (ndarray): Image formatted as Height * Width * Nchannels.
            bound_h (int): Height of the output canvas.
            bound_w (int): Width of the output canvas.
            interp (int): OpenCV interpolation flag.

        Returns:
            ndarray: The warped image, or ``img`` unchanged if it is empty.
        """
        if len(img) == 0:
            return img
        return cv2.warpAffine(
            img, self.rm_image, (bound_w, bound_h), flags=interp)

    def apply_coords(self, coords):
        """Transform points with the stored coordinate rotation matrix.

        Args:
            coords (array-like): N * 2 array of (x, y) points.

        Returns:
            ndarray: The transformed N * 2 points, or ``coords`` unchanged if
            it is empty.
        """
        if len(coords) == 0:
            return coords
        coords = np.asarray(coords, dtype=float)
        return cv2.transform(coords[:, np.newaxis, :], self.rm_coords)[:, 0, :]

    def create_rotation_matrix(self,
                               center,
                               angle,
                               bound_h,
                               bound_w,
                               offset=0):
        """Create the 2x3 affine rotation matrix.

        Args:
            center (array-like): (x, y) of the rotation centre.
            angle (float): Rotation angle in degrees (as expected by
                ``cv2.getRotationMatrix2D``).
            bound_h (int): Height of the output canvas.
            bound_w (int): Width of the output canvas.
            offset (float): Shift applied to the centre coordinates.

        Returns:
            ndarray: The 2x3 affine matrix.
        """
        # Keep the shifted centre as an ndarray: the original rebound it to a
        # tuple, which cannot be indexed with ``[None, None, :]`` below, and
        # then added ``offset`` a second time.
        shifted_center = np.asarray(center, dtype=float) + offset
        rm = cv2.getRotationMatrix2D(
            (float(shifted_center[0]), float(shifted_center[1])), angle, 1)
        if self.auto_bound:
            # Translate so the rotated centre lands in the middle of the
            # enlarged canvas.
            rot_im_center = cv2.transform(shifted_center[None, None, :],
                                          rm)[0, 0, :]
            new_center = np.array([bound_w / 2, bound_h / 2
                                   ]) + offset - rot_im_center
            rm[:, 2] += new_center
        return rm

    def filter_border(self, bboxes, h, w):
        """Build a keep-mask for boxes after rotation.

        A box is kept when its centre lies strictly inside the ``w`` x ``h``
        canvas and both of its side lengths are greater than 5.

        Args:
            bboxes (ndarray): Boxes of shape (N, 5).
            h (int): Canvas height.
            w (int): Canvas width.

        Returns:
            ndarray: Boolean mask of shape (N,).
        """
        x_ctr, y_ctr = bboxes[:, 0], bboxes[:, 1]
        w_bbox, h_bbox = bboxes[:, 2], bboxes[:, 3]
        keep_inds = (x_ctr > 0) & (x_ctr < w) & (y_ctr > 0) & (y_ctr < h) & \
                    (w_bbox > 5) & (h_bbox > 5)
        return keep_inds

    def __call__(self, results):
        """Call function of PolyRandomRotate.

        Args:
            results (dict): Result dict; reads ``img``, ``img_shape``,
                ``gt_bboxes`` and ``gt_labels``.

        Returns:
            dict | None: The updated dict, or None when no box survives the
            rotation.
        """
        if not self.is_rotate:
            results['rotate'] = False
            angle = 0
        else:
            results['rotate'] = True
            if self.mode == 'range':
                angle = self.angles_range * (2 * np.random.rand() - 1)
            else:
                i = np.random.randint(len(self.angles_range))
                angle = self.angles_range[i]

            # If any label belongs to rect_classes, fall back to one of the
            # discrete candidate angles.
            if self.rect_classes:
                for classid in results['gt_labels']:
                    if classid in self.rect_classes:
                        np.random.shuffle(self.discrete_range)
                        angle = self.discrete_range[0]
                        break

        h, w, c = results['img_shape']
        img = results['img']
        results['rotate_angle'] = angle

        image_center = np.array((w / 2, h / 2))
        # ``angle`` is in degrees; NumPy trigonometry expects radians.
        angle_rad = np.deg2rad(angle)
        abs_cos, abs_sin = abs(np.cos(angle_rad)), abs(np.sin(angle_rad))
        if self.auto_bound:
            bound_w, bound_h = np.rint(
                [h * abs_sin + w * abs_cos,
                 h * abs_cos + w * abs_sin]).astype(int)
        else:
            bound_w, bound_h = w, h

        self.rm_coords = self.create_rotation_matrix(image_center, angle,
                                                     bound_h, bound_w)
        self.rm_image = self.create_rotation_matrix(
            image_center, angle, bound_h, bound_w, offset=-0.5)

        img = self.apply_image(img, bound_h, bound_w)
        results['img'] = img
        results['img_shape'] = (bound_h, bound_w, c)

        gt_bboxes = results.get('gt_bboxes', [])
        labels = results.get('gt_labels', [])
        # A zero column is appended before the conversion and the last
        # column of the result is dropped again.
        # NOTE(review): presumably obb2poly_np expects a trailing score
        # column - confirm against its implementation.
        gt_bboxes = np.concatenate(
            [gt_bboxes, np.zeros((gt_bboxes.shape[0], 1))], axis=-1)
        polys = obb2poly_np(gt_bboxes, self.version)[:, :-1].reshape(-1, 2)
        polys = self.apply_coords(polys).reshape(-1, 8)

        rotated_boxes = []
        for pt in polys:
            # Convert once; a failed conversion becomes an all-zero box that
            # filter_border discards.
            obb = poly2obb_np(np.array(pt, dtype=np.float32), self.version)
            rotated_boxes.append(obb if obb is not None else [0, 0, 0, 0, 0])
        # reshape keeps the array 2-D even when there are no boxes.
        gt_bboxes = np.array(rotated_boxes, dtype=np.float32).reshape(-1, 5)

        keep_inds = self.filter_border(gt_bboxes, bound_h, bound_w)
        gt_bboxes = gt_bboxes[keep_inds, :]
        labels = labels[keep_inds]
        if len(gt_bboxes) == 0:
            return None
        results['gt_bboxes'] = gt_bboxes
        results['gt_labels'] = labels

        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(rotate_ratio={self.rotate_ratio}, ' \
                    f'mode={self.mode}, ' \
                    f'angles_range={self.angles_range}, ' \
                    f'auto_bound={self.auto_bound})'
        return repr_str
@ROTATED_PIPELINES.register_module()
class RRandomCrop(RandomCrop):
    """Random crop the image & bboxes.

    The absolute `crop_size` is sampled based on `crop_type` and `image_size`,
    then the cropped results are generated.

    Args:
        crop_size (tuple): The relative ratio or absolute pixels of
            height and width.
        crop_type (str, optional): one of "relative_range", "relative",
            "absolute", "absolute_range". "relative" randomly crops
            (h * crop_size[0], w * crop_size[1]) part from an input of size
            (h, w). "relative_range" uniformly samples relative crop size from
            range [crop_size[0], 1] and [crop_size[1], 1] for height and width
            respectively. "absolute" crops from an input with absolute size
            (crop_size[0], crop_size[1]). "absolute_range" uniformly samples
            crop_h in range [crop_size[0], min(h, crop_size[1])] and crop_w
            in range [crop_size[0], min(w, crop_size[1])]. Default "absolute".
        allow_negative_crop (bool, optional): Whether to allow a crop that does
            not contain any bbox area. Default False.
        version (str, optional): Angle representations. Defaults to 'oc'.

    Note:
        - If the image is smaller than the absolute crop size, return the
            original image.
        - The keys for bboxes, labels must be aligned. That is, `gt_bboxes`
          corresponds to `gt_labels`, and `gt_bboxes_ignore` corresponds to
          `gt_labels_ignore`.
        - If the crop does not contain any gt-bbox region and
          `allow_negative_crop` is set to False, skip this image.
    """

    def __init__(self,
                 crop_size,
                 crop_type='absolute',
                 allow_negative_crop=False,
                 version='oc'):
        self.version = version
        super().__init__(crop_size, crop_type, allow_negative_crop)

    def _crop_data(self, results, crop_size, allow_negative_crop):
        """Randomly crop the images and shift/filter the rotated boxes.

        Args:
            results (dict): Result dict from loading pipeline.
            crop_size (tuple): Expected absolute size after cropping, (h, w).
            allow_negative_crop (bool): Whether to allow a crop that does not
                contain any bbox area. Default to False.

        Returns:
            dict | None: Randomly cropped results with the 'img_shape' key
            updated according to the crop, or None when no `gt_bboxes`
            centre falls inside the crop and negative crops are not allowed.
        """
        assert crop_size[0] > 0 and crop_size[1] > 0
        for field in results.get('bbox_fields', []):
            assert results[field].shape[-1] % 5 == 0

        crop_h, crop_w = crop_size[0], crop_size[1]
        for img_key in results.get('img_fields', ['img']):
            image = results[img_key]
            # The margin is how far the crop window can slide; it is zero
            # when the image is not larger than the crop.
            margin_h = max(image.shape[0] - crop_h, 0)
            margin_w = max(image.shape[1] - crop_w, 0)
            offset_h = np.random.randint(0, margin_h + 1)
            offset_w = np.random.randint(0, margin_w + 1)

            image = image[offset_h:offset_h + crop_h,
                          offset_w:offset_w + crop_w, ...]
            img_shape = image.shape
            results[img_key] = image
        results['img_shape'] = img_shape

        height, width, _ = img_shape

        # Shift box centres into crop coordinates and keep only the boxes
        # whose centre lies inside the cropped image.
        for field in results.get('bbox_fields', []):
            # e.g. gt_bboxes and gt_bboxes_ignore
            shift = np.array([offset_w, offset_h, 0, 0, 0], dtype=np.float32)
            shifted = results[field] - shift
            ctr_x, ctr_y = shifted[:, 0], shifted[:, 1]
            inside_x = (ctr_x >= 0) & (ctr_x < width)
            inside_y = (ctr_y >= 0) & (ctr_y < height)
            valid_inds = inside_x & inside_y

            # An empty crop of the main gt boxes makes the whole sample be
            # skipped unless negative crops are allowed.
            if field == 'gt_bboxes' and not (valid_inds.any()
                                             or allow_negative_crop):
                return None
            results[field] = shifted[valid_inds, :]

            # label fields. e.g. gt_labels and gt_labels_ignore
            label_key = self.bbox2label.get(field)
            if label_key in results:
                results[label_key] = results[label_key][valid_inds]

        return results
Read the Docs v: v0.3.1
Versions
latest
stable
v0.3.1
v0.3.0
v0.2.0
v0.1.1
v0.1.0
main
dev
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.