# Source code for mmrotate.core.bbox.coder.angle_coder
# Copyright (c) OpenMMLab. All rights reserved.
import math
import torch
from mmdet.core.bbox.coder.base_bbox_coder import BaseBBoxCoder
from ..builder import ROTATED_BBOX_CODERS
@ROTATED_BBOX_CODERS.register_module()
class CSLCoder(BaseBBoxCoder):
    """Circular Smooth Label Coder.

    `Circular Smooth Label (CSL)
    <https://link.springer.com/chapter/10.1007/978-3-030-58598-3_40>`_ .

    The angle range of the chosen ``angle_version`` (90 degrees for ``'oc'``,
    180 degrees otherwise) is split into ``coding_len = angle_range // omega``
    bins. ``encode`` turns each angle into a soft label over those bins using
    the chosen window; ``decode`` maps the arg-max bin back to the angle at
    the centre of that bin.

    Args:
        angle_version (str): Angle definition, one of ``'oc'``, ``'le90'``,
            ``'le135'``.
        omega (float, optional): Angle discretization granularity, in degrees
            per bin. Default: 1.
        window (str, optional): Window function, one of ``'gaussian'``,
            ``'triangle'``, ``'rect'``, ``'pulse'``. Default: gaussian.
        radius (int/float): Window radius, measured in bins. int type for
            ['triangle', 'rect', 'pulse'], float type for ['gaussian'].
            Not used by 'pulse'. Default: 6.
    """

    def __init__(self, angle_version, omega=1, window='gaussian', radius=6):
        super().__init__()
        self.angle_version = angle_version
        assert angle_version in ['oc', 'le90', 'le135']
        assert window in ['gaussian', 'triangle', 'rect', 'pulse']
        # Width (in degrees) of the angle interval being discretized.
        self.angle_range = 90 if angle_version == 'oc' else 180
        # Degrees added to an angle before binning in ``encode`` and
        # subtracted again in ``decode``.
        self.angle_offset_dict = {'oc': 0, 'le90': 90, 'le135': 45}
        self.angle_offset = self.angle_offset_dict[angle_version]
        self.omega = omega
        self.window = window
        self.radius = radius
        # Number of bins, i.e. the length of each encoded label vector.
        self.coding_len = int(self.angle_range // omega)

    def encode(self, angle_targets):
        """Circular Smooth Label Encoder.

        Args:
            angle_targets (Tensor): Angle offset for each scale level, in
                radians. Has shape (num_anchors * H * W, 1)

        Returns:
            Tensor: The csl encoding of angle offset for each scale level,
                with the same dtype and device as ``angle_targets``.
                Has shape (num_anchors * H * W, coding_len)
        """
        # radian to degree
        angle_targets_deg = angle_targets * (180 / math.pi)
        # empty label
        smooth_label = torch.zeros_like(angle_targets).repeat(
            1, self.coding_len)
        # Fractional bin position of each target.
        angle_targets_deg = (angle_targets_deg +
                             self.angle_offset) / self.omega
        # Float to Int. ``floor`` (instead of the truncation toward zero done
        # by ``.long()``) lets a slightly negative position wrap into the last
        # bin via the modulo below instead of collapsing into bin 0; both agree
        # for non-negative positions.
        angle_targets_long = torch.floor(angle_targets_deg).long()
        device = angle_targets_long.device

        if self.window == 'pulse':
            # Only the target bin is set.
            radius_range = angle_targets_long % self.coding_len
            smooth_value = 1.0
        elif self.window == 'rect':
            # Bins at offsets -radius .. radius-1 around the target are set.
            base_radius_range = torch.arange(
                -self.radius, self.radius, device=device)
            radius_range = (base_radius_range +
                            angle_targets_long) % self.coding_len
            smooth_value = 1.0
        elif self.window == 'triangle':
            # Same offsets as 'rect', value decaying linearly with |offset|.
            base_radius_range = torch.arange(
                -self.radius, self.radius, device=device)
            radius_range = (base_radius_range +
                            angle_targets_long) % self.coding_len
            smooth_value = 1.0 - torch.abs(
                (1 / self.radius) * base_radius_range)
        elif self.window == 'gaussian':
            # Exactly ``coding_len`` consecutive bin offsets, so every bin is
            # written once by ``scatter``. Spanning ``angle_range`` offsets
            # instead would hit some bins twice whenever omega != 1, and
            # scatter's result for duplicate indices is undefined.
            base_radius_range = torch.arange(
                -self.coding_len // 2,
                self.coding_len // 2,
                device=device)
            radius_range = (base_radius_range +
                            angle_targets_long) % self.coding_len
            smooth_value = torch.exp(-torch.pow(base_radius_range, 2) /
                                     (2 * self.radius**2))
        else:
            raise NotImplementedError

        if isinstance(smooth_value, torch.Tensor):
            # ``scatter`` requires ``src`` to have the label's dtype; the
            # window values are built in the default float dtype.
            smooth_value = smooth_value.to(smooth_label.dtype)
            smooth_value = smooth_value.unsqueeze(0).repeat(
                smooth_label.size(0), 1)
        return smooth_label.scatter(1, radius_range, smooth_value)

    def decode(self, angle_preds):
        """Circular Smooth Label Decoder.

        Args:
            angle_preds (Tensor): The csl encoding of angle offset
                for each scale level.
                Has shape (num_anchors * H * W, coding_len)

        Returns:
            Tensor: Angle offset for each scale level, in radians, taken at
                the centre of the highest-scoring bin.
                Has shape (num_anchors * H * W,)
        """
        angle_cls_inds = torch.argmax(angle_preds, dim=1)
        # Bin centre in degrees, wrapped into [0, angle_range), then shifted
        # back by the offset that ``encode`` added.
        angle_pred = ((angle_cls_inds + 0.5) *
                      self.omega) % self.angle_range - self.angle_offset
        return angle_pred * (math.pi / 180)