# Source code for mmrotate.models.losses.spatial_border_loss
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.ops import points_in_polygons
from ..builder import ROTATED_LOSSES
@ROTATED_LOSSES.register_module()
class SpatialBorderLoss(nn.Module):
    """Spatial border loss for learning points in Oriented RepPoints.

    Penalizes learning points that fall outside their assigned ground-truth
    polygon. The actual computation is delegated to
    ``weighted_spatial_border_loss``; this module only scales its result.

    Args:
        loss_weight (float): Multiplier applied to the computed loss.
            Defaults to 1.0.
    """

    def __init__(self, loss_weight=1.0):
        super().__init__()
        self.loss_weight = loss_weight

    def forward(self, pts, gt_bboxes, weight, *args, **kwargs):
        """Compute the scaled spatial border loss.

        Args:
            pts (torch.Tensor): Point sets with shape (N, 9*2). The default
                number of points in each point set is 9.
            gt_bboxes (torch.Tensor): GT boxes in polygon form with shape
                (N, 8).
            weight (torch.Tensor): Per-point-set weights with shape (N,).
            *args, **kwargs: Forwarded unchanged to
                ``weighted_spatial_border_loss`` (e.g. ``avg_factor``).

        Returns:
            torch.Tensor: ``loss_weight`` times the weighted border loss.
        """
        border_loss = weighted_spatial_border_loss(
            pts, gt_bboxes, weight, *args, **kwargs)
        return self.loss_weight * border_loss
def spatial_border_loss(pts, gt_bboxes):
    """Penalize learning points that fall outside their assigned GT polygon.

    Point set ``pts[i]`` is paired with polygon ``gt_bboxes[i]``. Every point
    of a set is tested against its paired polygon with ``points_in_polygons``.
    Each point reported as outside (flag == 0) contributes
    ``0.2 * ||p - c||``, where ``c`` is the midpoint of the polygon's first
    and third vertices.

    Args:
        pts (torch.Tensor): Point sets with shape (N, num_points * 2), laid
            out as ``[x0, y0, x1, y1, ...]``. The default number of points in
            each set is 9.
        gt_bboxes (torch.Tensor): GT boxes in polygon form with shape (N, 8),
            laid out as ``[x1, y1, x2, y2, x3, y3, x4, y4]``. Row ``i`` is the
            target of point set ``i``.

    Returns:
        torch.Tensor: The mean scaled distance over all outside points.
        Returns an empty tensor of shape (0,) when there are no polygons,
        no point sets, or no point lies outside its polygon.

    Raises:
        ValueError: If both inputs are non-empty but have a different number
            of rows.
    """
    num_gts, num_pointsets = gt_bboxes.size(0), pts.size(0)
    num_point = pts.size(1) // 2
    loss = pts.new_zeros([0])
    if num_gts == 0 or num_pointsets == 0:
        # Nothing to test against / nothing to test.
        return loss
    if num_pointsets != num_gts:
        # Only the diagonal of the pairwise inside matrix is used below, so a
        # row-count mismatch would silently pair points with the wrong GTs.
        raise ValueError(
            'spatial_border_loss expects one GT polygon per point set, '
            'got {} point sets and {} polygons'.format(num_pointsets, num_gts))

    inside_flag_list = []
    for i in range(num_point):
        pt = pts[:, (2 * i):(2 * i + 2)].reshape(num_pointsets,
                                                 2).contiguous()
        # points_in_polygons tests every point against every polygon; the
        # diagonal keeps each point's result for its own (paired) polygon.
        inside_flag_list.append(torch.diag(points_in_polygons(pt, gt_bboxes)))
    inside_flag = torch.stack(inside_flag_list, dim=1)  # (N, num_point)

    # Locate outside points once and reuse the indices for points and GTs.
    set_idx, point_idx = torch.where(inside_flag == 0)
    if set_idx.numel() == 0:
        return loss
    out_border_pts = pts.reshape(-1, num_point, 2)[set_idx, point_idx]

    # Midpoint of vertex 0 and vertex 2 (a diagonal of the polygon), per GT.
    gt_centers = (gt_bboxes[:, 0:2] + gt_bboxes[:, 4:6]) / 2.0
    corr_gt_boxes_center = gt_centers[set_idx]
    distance_out_pts = 0.2 * ((
        (out_border_pts - corr_gt_boxes_center)**2).sum(dim=1).sqrt())
    return distance_out_pts.sum() / out_border_pts.size(0)
def weighted_spatial_border_loss(pts, gt_bboxes, weight, avg_factor=None):
    """Weighted spatial border loss.

    ``weight`` is used only to derive the default normalizer (the number of
    positive weights). The per-sample border losses are not multiplied by it.

    Args:
        pts (torch.Tensor): Point sets with shape (N, 9*2).
        gt_bboxes (torch.Tensor): GT boxes in polygon form with shape (N, 8).
        weight (torch.Tensor): 1-D weights for the point sets with shape (N,).
        avg_factor (float, optional): Normalizer for the summed loss. If
            None, the number of entries of ``weight`` greater than zero is
            used, plus 1e-6 to avoid division by zero.

    Returns:
        torch.Tensor: Loss tensor of shape (1,).

    Raises:
        ValueError: If ``weight`` is not a 1-D tensor.
    """
    # Explicit check (not an assert, which ``python -O`` strips) so a
    # wrongly shaped weight fails with a clear message.
    if weight.dim() != 1:
        raise ValueError(
            'weight must be a 1-D tensor of shape (N,), got shape {}'.format(
                tuple(weight.shape)))
    if avg_factor is None:
        avg_factor = (weight > 0).sum().float().item() + 1e-6
    loss = spatial_border_loss(pts, gt_bboxes)
    return torch.sum(loss)[None] / avg_factor