Shortcuts

Source code for mmrotate.core.bbox.coder.angle_coder

# Copyright (c) OpenMMLab. All rights reserved.
import math

import torch
from mmdet.core.bbox.coder.base_bbox_coder import BaseBBoxCoder

from ..builder import ROTATED_BBOX_CODERS


@ROTATED_BBOX_CODERS.register_module()
class CSLCoder(BaseBBoxCoder):
    """Circular Smooth Label Coder.

    `Circular Smooth Label (CSL)
    <https://link.springer.com/chapter/10.1007/978-3-030-58598-3_40>`_ .

    The angle range of ``angle_version`` is discretized into ``coding_len``
    bins of ``omega`` degrees each. ``encode`` places a window function
    around the target bin (wrapping circularly modulo ``coding_len``) and
    ``decode`` recovers the angle from the arg-max bin.

    Args:
        angle_version (str): Angle definition. One of 'oc', 'le90', 'le135'.
        omega (float, optional): Angle discretization granularity in
            degrees. Default: 1.
        window (str, optional): Window function. One of 'gaussian',
            'triangle', 'rect', 'pulse'. Default: gaussian.
        radius (int/float): Window radius, measured in bins. int type for
            ['triangle', 'rect', 'pulse'], float type for ['gaussian'].
            Default: 6.
    """

    def __init__(self, angle_version, omega=1, window='gaussian', radius=6):
        super().__init__()
        self.angle_version = angle_version
        assert angle_version in ['oc', 'le90', 'le135']
        assert window in ['gaussian', 'triangle', 'rect', 'pulse']
        # 'oc' covers a 90 degree range; the 'le*' definitions cover 180.
        self.angle_range = 90 if angle_version == 'oc' else 180
        # Degrees added before binning (and removed again in ``decode``);
        # presumably shifts each definition's angles into [0, angle_range).
        self.angle_offset_dict = {'oc': 0, 'le90': 90, 'le135': 45}
        self.angle_offset = self.angle_offset_dict[angle_version]
        self.omega = omega
        self.window = window
        self.radius = radius
        # Number of bins, i.e. the length of every encoded label.
        self.coding_len = int(self.angle_range // omega)

    def encode(self, angle_targets):
        """Circular Smooth Label Encoder.

        Args:
            angle_targets (Tensor): Angle offset (in radians) for each scale
                level. Has shape (num_anchors * H * W, 1).

        Returns:
            Tensor: The csl encoding of angle offset for each scale level.
                Has shape (num_anchors * H * W, coding_len).
        """
        # radian to degree
        angle_targets_deg = angle_targets * (180 / math.pi)
        # empty label
        smooth_label = torch.zeros_like(angle_targets).repeat(
            1, self.coding_len)
        angle_targets_deg = (angle_targets_deg +
                             self.angle_offset) / self.omega
        # Float to bin index. floor (not truncation toward zero) keeps
        # negative values in the correct bin after the circular wrap below;
        # for non-negative values both are identical.
        angle_targets_long = torch.floor(angle_targets_deg).long()

        if self.window == 'pulse':
            radius_range = angle_targets_long % self.coding_len
            smooth_value = 1.0
        elif self.window == 'rect':
            base_radius_range = torch.arange(
                -self.radius, self.radius, device=angle_targets_long.device)
            radius_range = (base_radius_range +
                            angle_targets_long) % self.coding_len
            smooth_value = 1.0
        elif self.window == 'triangle':
            base_radius_range = torch.arange(
                -self.radius, self.radius, device=angle_targets_long.device)
            radius_range = (base_radius_range +
                            angle_targets_long) % self.coding_len
            smooth_value = 1.0 - torch.abs(
                (1 / self.radius) * base_radius_range)
        elif self.window == 'gaussian':
            # Offsets are in bins, so span exactly one full circle of
            # ``coding_len`` bins. Spanning ``angle_range`` (degrees) would
            # produce duplicate scatter indices whenever omega != 1.
            base_radius_range = torch.arange(
                -self.coding_len // 2,
                self.coding_len // 2,
                device=angle_targets_long.device)
            radius_range = (base_radius_range +
                            angle_targets_long) % self.coding_len
            smooth_value = torch.exp(-torch.pow(base_radius_range, 2) /
                                     (2 * self.radius**2))
        else:
            raise NotImplementedError

        if isinstance(smooth_value, torch.Tensor):
            # Match the label dtype (scatter requires src and self to share
            # a dtype) and replicate the window for every target row.
            smooth_value = smooth_value.to(smooth_label).unsqueeze(0).repeat(
                smooth_label.size(0), 1)

        return smooth_label.scatter(1, radius_range, smooth_value)

    def decode(self, angle_preds, keepdim=False):
        """Circular Smooth Label Decoder.

        Args:
            angle_preds (Tensor): The csl encoding of angle offset for each
                scale level. Has shape (num_anchors * H * W, coding_len).
            keepdim (bool, optional): Whether to keep the reduced bin
                dimension in the output. Default: False.

        Returns:
            Tensor: Angle offset (in radians) for each scale level. Has shape
                (num_anchors * H * W,), or (num_anchors * H * W, 1) when
                ``keepdim`` is True.
        """
        angle_cls_inds = torch.argmax(angle_preds, dim=1, keepdim=keepdim)
        # Take the centre of the winning bin (in degrees), wrap it into the
        # angle range and undo the offset applied by ``encode``.
        angle_pred = ((angle_cls_inds + 0.5) *
                      self.omega) % self.angle_range - self.angle_offset
        return angle_pred * (math.pi / 180)
Read the Docs v: v0.3.4
Versions
latest
stable
1.x
v1.0.0rc0
v0.3.4
v0.3.3
v0.3.2
v0.3.1
v0.3.0
v0.2.0
v0.1.1
v0.1.0
main
dev
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.