Shortcuts

Source code for mmrotate.models.losses.spatial_border_loss

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.ops import points_in_polygons

from ..builder import ROTATED_LOSSES


@ROTATED_LOSSES.register_module()
class SpatialBorderLoss(nn.Module):
    """Spatial Border loss for learning points in Oriented RepPoints.

    Args:
        loss_weight (float): Scalar multiplied onto the computed loss.
            Defaults to 1.0.
    """

    def __init__(self, loss_weight=1.0):
        super().__init__()
        self.loss_weight = loss_weight

    def forward(self, pts, gt_bboxes, weight, *args, **kwargs):
        """Compute the scaled, weighted spatial border loss.

        Args:
            pts (torch.Tensor): point sets with shape (N, 9*2).
                Default points number in each point set is 9.
            gt_bboxes (torch.Tensor): gt_bboxes with polygon form with
                shape (N, 8).
            weight (torch.Tensor): weights for point sets with shape (N).
            *args: Extra positional arguments forwarded unchanged to
                ``weighted_spatial_border_loss``.
            **kwargs: Extra keyword arguments forwarded unchanged to
                ``weighted_spatial_border_loss`` (e.g. ``avg_factor``).

        Returns:
            loss (torch.Tensor): ``loss_weight`` times the value returned by
                ``weighted_spatial_border_loss``.
        """
        unscaled_loss = weighted_spatial_border_loss(
            pts, gt_bboxes, weight, *args, **kwargs)
        return self.loss_weight * unscaled_loss
def spatial_border_loss(pts, gt_bboxes):
    """Penalize learned points that lie outside their assigned gt polygon.

    Point set ``i`` is tested against ground-truth polygon ``i`` only (the
    diagonal of the pairwise ``points_in_polygons`` result). Every point
    found outside its polygon contributes ``0.2`` times its Euclidean
    distance to the polygon center, and the loss is the mean of these
    contributions over all outside points.

    Args:
        pts (torch.Tensor): point sets with shape (N, 9*2).
        gt_bboxes (torch.Tensor): gt_bboxes with polygon form with
            shape (N, 8).

    Returns:
        loss (torch.Tensor): a scalar tensor holding the mean penalty, or
            an empty tensor of shape (0,) when there are no gt boxes or no
            point lies outside its polygon.
    """
    num_gts, num_pointsets = gt_bboxes.size(0), pts.size(0)
    num_point = pts.size(1) // 2
    loss = pts.new_zeros([0])
    if num_gts == 0:
        return loss

    inside_flag_list = []
    for i in range(num_point):
        # The i-th (x, y) pair of every point set.
        pt = pts[:, (2 * i):(2 * i + 2)].reshape(num_pointsets,
                                                 2).contiguous()
        # points_in_polygons is pairwise; only the matched pairs on the
        # diagonal (point set k vs. gt k) are relevant here.
        inside_flag_list.append(
            torch.diag(points_in_polygons(pt, gt_bboxes)))
    inside_flag = torch.stack(inside_flag_list, dim=1)

    pts = pts.reshape(-1, num_point, 2)
    # (row, col) indices of the points outside their polygon; computed once
    # and reused for both the points and their matching gt boxes.
    outside_idx = torch.where(inside_flag == 0)
    out_border_pts = pts[outside_idx]
    if out_border_pts.size(0) == 0:
        return loss

    corr_gt_boxes = gt_bboxes[outside_idx[0]]
    # Center taken as the midpoint of coordinates (0, 1) and (4, 5), i.e.
    # the first and third vertices assuming an (x1, y1, ..., x4, y4) layout.
    corr_gt_boxes_center = (corr_gt_boxes[:, 0:2] +
                            corr_gt_boxes[:, 4:6]) / 2.0
    distance_out_pts = 0.2 * (
        ((out_border_pts - corr_gt_boxes_center)**2).sum(dim=1).sqrt())
    return distance_out_pts.sum() / out_border_pts.size(0)


def weighted_spatial_border_loss(pts, gt_bboxes, weight, avg_factor=None):
    """Weighted spatial border loss.

    ``weight`` is not multiplied into the loss; it is only used to derive
    the default normalizer (the number of strictly positive weights).

    Args:
        pts (torch.Tensor): point sets with shape (N, 9*2).
        gt_bboxes (torch.Tensor): gt_bboxes with polygon form with
            shape (N, 8).
        weight (torch.Tensor): weights for point sets with shape (N).
        avg_factor (float, optional): value the loss is divided by. When
            ``None``, the count of positive entries in ``weight`` plus
            ``1e-6`` is used.

    Returns:
        loss (torch.Tensor): tensor of shape (1,).
    """
    if avg_factor is None:
        # Counting positives directly; the small epsilon keeps the division
        # finite when no weight is positive.
        avg_factor = float((weight > 0).sum().item()) + 1e-6
    loss = spatial_border_loss(pts, gt_bboxes)
    # torch.sum collapses both the scalar and the empty (0,) case to 0-dim.
    return torch.sum(loss)[None] / avg_factor
Read the Docs v: v0.3.4
Versions
latest
stable
1.x
v1.0.0rc0
v0.3.4
v0.3.3
v0.3.2
v0.3.1
v0.3.0
v0.2.0
v0.1.1
v0.1.0
main
dev
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.