Shortcuts

Source code for mmrotate.models.losses.spatial_border_loss

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.ops import points_in_polygons

from ..builder import ROTATED_LOSSES


@ROTATED_LOSSES.register_module()
class SpatialBorderLoss(nn.Module):
    """Spatial border loss for the learned points of Oriented RepPoints.

    Thin module wrapper around ``weighted_spatial_border_loss`` that
    multiplies the result by a constant factor.

    Args:
        loss_weight (float): Multiplier applied to the computed loss.
            Defaults to 1.0.
    """

    def __init__(self, loss_weight=1.0):
        super().__init__()
        self.loss_weight = loss_weight

    def forward(self, pts, gt_bboxes, weight, *args, **kwargs):
        """Compute the scaled spatial border loss.

        Args:
            pts (torch.Tensor): Point sets with shape (N, 9*2). The default
                number of points in each point set is 9.
            gt_bboxes (torch.Tensor): Ground-truth boxes in polygon form
                with shape (N, 8).
            weight (torch.Tensor): Weights for the point sets with
                shape (N).
            *args: Extra positional arguments forwarded unchanged to
                ``weighted_spatial_border_loss``.
            **kwargs: Extra keyword arguments forwarded unchanged to
                ``weighted_spatial_border_loss`` (e.g. ``avg_factor``).

        Returns:
            torch.Tensor: The loss multiplied by ``loss_weight``.
        """
        raw_loss = weighted_spatial_border_loss(pts, gt_bboxes, weight, *args,
                                                **kwargs)
        return self.loss_weight * raw_loss
def spatial_border_loss(pts, gt_bboxes):
    """Penalize learned points lying outside their assigned gt polygon.

    Every point that is flagged as outside the ground-truth polygon of its
    own point set contributes 0.2 times its Euclidean distance to a
    reference centre of that polygon; the result is the mean contribution
    over all such points.

    Args:
        pts (torch.Tensor): Point sets with shape (N, 9*2).
        gt_bboxes (torch.Tensor): Ground-truth boxes in polygon form with
            shape (N, 8).

    Returns:
        torch.Tensor: Scalar mean penalty when at least one point is
        outside its polygon; otherwise an empty tensor of shape (0,)
        created from ``pts``.
    """
    box_count = gt_bboxes.size(0)
    set_count = pts.size(0)
    pts_per_set = int(pts.size(1) / 2.0)
    # Placeholder result for "nothing to penalize"; summing it yields 0.
    empty_loss = pts.new_zeros([0])

    if box_count == 0:
        return empty_loss

    # One flag vector per point slot. ``points_in_polygons`` is presumably a
    # (points x polygons) 0/1 matrix -- TODO confirm against mmcv; the
    # diagonal keeps only the entry pairing point set i with gt box i.
    flags_per_slot = [
        torch.diag(
            points_in_polygons(
                pts[:, (2 * k):(2 * k + 2)].reshape(set_count,
                                                    2).contiguous(),
                gt_bboxes)) for k in range(pts_per_set)
    ]
    inside_flag = torch.stack(flags_per_slot, dim=1)

    # Index tuple of every point whose flag is 0, i.e. outside its polygon.
    outside_idx = torch.where(inside_flag == 0)
    outside_pts = pts.reshape(-1, pts_per_set, 2)[outside_idx]

    if outside_pts.size(0) == 0:
        return empty_loss

    # First index component selects the gt box matching each outside point.
    matched_boxes = gt_bboxes[outside_idx[0]]
    # Midpoint of polygon coordinates (0, 1) and (4, 5); assuming an
    # (x1, y1, ..., x4, y4) layout this is the centre of one diagonal.
    centers = (matched_boxes[:, 0:2] + matched_boxes[:, 4:6]) / 2.0

    offsets = outside_pts - centers
    scaled_dist = 0.2 * ((offsets**2).sum(dim=1).sqrt())
    return scaled_dist.sum() / outside_pts.size(0)


def weighted_spatial_border_loss(pts, gt_bboxes, weight, avg_factor=None):
    """Weighted spatial border loss.

    Note that ``weight`` is only used to derive the default ``avg_factor``
    (the number of positive entries); its values do not scale the loss.

    Args:
        pts (torch.Tensor): Point sets with shape (N, 9*2).
        gt_bboxes (torch.Tensor): Ground-truth boxes in polygon form with
            shape (N, 8).
        weight (torch.Tensor): Weights for the point sets with shape (N).
        avg_factor (float, optional): Normalizer for the loss. When None,
            the count of positive weights plus 1e-6 is used.

    Returns:
        torch.Tensor: Loss tensor of shape (1,).
    """
    tiled_weight = weight.unsqueeze(dim=1).repeat(1, 4)
    assert tiled_weight.dim() == 2

    if avg_factor is None:
        # The weight was tiled 4x above, so divide the positive count by 4;
        # the epsilon avoids division by zero when nothing is positive.
        positive_count = (tiled_weight > 0).sum().float().item()
        avg_factor = positive_count / 4 + 1e-6

    border_loss = spatial_border_loss(pts, gt_bboxes)
    return border_loss.sum()[None] / avg_factor
Read the Docs v: v0.3.2
Versions
latest
stable
v0.3.2
v0.3.1
v0.3.0
v0.2.0
v0.1.1
v0.1.0
main
dev
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.