Shortcuts

Source code for mmrotate.core.bbox.coder.angle_coder

# Copyright (c) OpenMMLab. All rights reserved.
import math

import torch
from mmdet.core.bbox.coder.base_bbox_coder import BaseBBoxCoder

from ..builder import ROTATED_BBOX_CODERS


@ROTATED_BBOX_CODERS.register_module()
class CSLCoder(BaseBBoxCoder):
    """Circular Smooth Label Coder.

    `Circular Smooth Label (CSL)
    <https://link.springer.com/chapter/10.1007/978-3-030-58598-3_40>`_ .

    Angles are discretized into ``coding_len = angle_range // omega`` bins
    and encoded as a soft label: the target bin and its neighbours (with
    circular wrap-around, via modulo ``coding_len``) receive values from a
    window function.

    Args:
        angle_version (str): Angle definition. One of ``'oc'`` (90 degree
            range, offset 0), ``'le90'`` (180 degree range, offset 90) or
            ``'le135'`` (180 degree range, offset 45).
        omega (float, optional): Angle discretization granularity, in
            degrees per bin. Should divide the angle range evenly; otherwise
            angles in the trailing partial bin wrap around to bin 0.
            Default: 1.
        window (str, optional): Window function. One of ``'gaussian'``,
            ``'triangle'``, ``'rect'`` or ``'pulse'``. Default: gaussian.
        radius (int/float): Window radius, measured in bins. int type for
            ['triangle', 'rect'], float type for ['gaussian']; ignored by
            'pulse'. Default: 6.
    """

    def __init__(self, angle_version, omega=1, window='gaussian', radius=6):
        super().__init__()
        self.angle_version = angle_version
        assert angle_version in ['oc', 'le90', 'le135']
        assert window in ['gaussian', 'triangle', 'rect', 'pulse']
        # 'oc' angles span 90 degrees; the 'le*' definitions span 180.
        self.angle_range = 90 if angle_version == 'oc' else 180
        # Offset (degrees) added before binning in ``encode`` and subtracted
        # again in ``decode``; presumably shifts each definition's angle
        # range so that it starts at 0.
        self.angle_offset_dict = {'oc': 0, 'le90': 90, 'le135': 45}
        self.angle_offset = self.angle_offset_dict[angle_version]
        self.omega = omega
        self.window = window
        self.radius = radius
        # Number of angle bins, i.e. the length of each CSL encoding.
        self.coding_len = int(self.angle_range // omega)

    def encode(self, angle_targets):
        """Circular Smooth Label Encoder.

        Args:
            angle_targets (Tensor): Angle offset for each scale level, in
                radians. Has shape (num_anchors * H * W, 1).

        Returns:
            Tensor: The csl encoding of angle offset for each scale level.
            Has shape (num_anchors * H * W, coding_len), with the same dtype
            and device as ``angle_targets``.
        """
        # radian to degree
        angle_targets_deg = angle_targets * (180 / math.pi)
        # empty label
        smooth_label = torch.zeros_like(angle_targets).repeat(
            1, self.coding_len)
        # Apply the per-definition offset and convert degrees to bin units.
        angle_targets_deg = (angle_targets_deg +
                             self.angle_offset) / self.omega
        # Float to Int (``.long()`` truncates toward zero): target bin index.
        angle_targets_long = angle_targets_deg.long()
        device = angle_targets_long.device

        if self.window == 'pulse':
            # Plain one-hot label on the target bin.
            radius_range = angle_targets_long % self.coding_len
            smooth_value = 1.0
        elif self.window == 'rect':
            # Offsets -radius .. radius-1 (2 * radius bins) all get 1.0.
            # NOTE: if 2 * radius > coding_len the wrapped indices repeat.
            base_radius_range = torch.arange(
                -self.radius, self.radius, device=device)
            radius_range = (base_radius_range +
                            angle_targets_long) % self.coding_len
            smooth_value = 1.0
        elif self.window == 'triangle':
            # Linear decay 1 - |offset| / radius; the -radius offset gets 0.
            base_radius_range = torch.arange(
                -self.radius, self.radius, device=device)
            radius_range = (base_radius_range +
                            angle_targets_long) % self.coding_len
            smooth_value = 1.0 - torch.abs(
                (1 / self.radius) * base_radius_range)
        elif self.window == 'gaussian':
            # Cover exactly ``coding_len`` offsets around the target so that
            # every bin index is produced once. Spanning ``angle_range``
            # offsets instead would repeat indices whenever omega != 1, and
            # scatter with duplicate indices has an undefined result.
            base_radius_range = torch.arange(
                -self.coding_len // 2, self.coding_len // 2, device=device)
            radius_range = (base_radius_range +
                            angle_targets_long) % self.coding_len
            smooth_value = torch.exp(-torch.pow(base_radius_range, 2) /
                                     (2 * self.radius**2))
        else:
            raise NotImplementedError

        if isinstance(smooth_value, torch.Tensor):
            # scatter requires src to match self's dtype; the window values
            # are computed in the default float dtype, so cast them to the
            # label dtype (e.g. for half/double precision targets).
            smooth_value = smooth_value.to(smooth_label.dtype)
            smooth_value = smooth_value.unsqueeze(0).repeat(
                smooth_label.size(0), 1)
        return smooth_label.scatter(1, radius_range, smooth_value)

    def decode(self, angle_preds):
        """Circular Smooth Label Decoder.

        Args:
            angle_preds (Tensor): The csl encoding of angle offset for each
                scale level. Has shape (num_anchors * H * W, coding_len).

        Returns:
            Tensor: Angle offset for each scale level, in radians. Has shape
            (num_anchors * H * W,), since ``argmax`` reduces the bin
            dimension away.
        """
        angle_cls_inds = torch.argmax(angle_preds, dim=1)
        # Take the centre of the winning bin (in degrees), then undo the
        # offset that ``encode`` applied.
        angle_pred = ((angle_cls_inds + 0.5) *
                      self.omega) % self.angle_range - self.angle_offset
        # degree to radian
        return angle_pred * (math.pi / 180)
Read the Docs v: v0.3.2
Versions
latest
stable
v0.3.2
v0.3.1
v0.3.0
v0.2.0
v0.1.1
v0.1.0
main
dev
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.