Shortcuts

Source code for mmrotate.models.losses.kf_iou_loss

# Copyright (c) SJTU. All rights reserved.
import torch
from mmdet.models.losses.utils import weighted_loss
from torch import nn

from ..builder import ROTATED_LOSSES


def xy_wh_r_2_xy_sigma(xywhr):
    """Turn rotated boxes into the mean/covariance of a 2-D Gaussian.

    The covariance is built as ``R * (S ** 2) * R^T`` where ``R`` is the
    2x2 rotation matrix of the box angle and ``S`` is the diagonal matrix
    holding half of the (clamped) width and height.

    Args:
        xywhr (torch.Tensor): rbboxes whose last dimension has size 5,
            read as (x, y, w, h, r). Typically of shape (N, 5).

    Returns:
        xy (torch.Tensor): center point of 2-D Gaussian distribution,
            i.e. the first two entries of the last dimension.
        sigma (torch.Tensor): covariance matrix of 2-D Gaussian
            distribution, with the leading dimensions of ``xywhr``
            followed by (2, 2).
    """
    in_shape = xywhr.shape
    assert in_shape[-1] == 5

    center = xywhr[..., :2]
    # Keep width/height inside [1e-7, 1e7] before they are squared below,
    # then flatten every leading dimension into one batch axis.
    size = xywhr[..., 2:4].clamp(min=1e-7, max=1e7).reshape(-1, 2)
    angle = xywhr[..., 4]

    cos_a = torch.cos(angle)
    sin_a = torch.sin(angle)
    # Row-major entries of [[cos, -sin], [sin, cos]] for every box.
    rotation = torch.stack((cos_a, -sin_a, sin_a, cos_a), dim=-1)
    rotation = rotation.reshape(-1, 2, 2)

    half_extent = 0.5 * torch.diag_embed(size)

    covariance = rotation.bmm(half_extent.square())
    covariance = covariance.bmm(rotation.permute(0, 2, 1))
    # Restore the caller's leading dimensions.
    sigma = covariance.reshape(in_shape[:-1] + (2, 2))

    return center, sigma


@weighted_loss
def kfiou_loss(pred,
               target,
               pred_decode=None,
               targets_decode=None,
               fun=None,
               beta=1.0 / 9.0,
               eps=1e-6):
    """Kalman filter IoU loss.

    The per-box loss is the sum of two parts: a Smooth-L1 penalty on the
    first two columns of ``pred`` and ``target``, and an overlap penalty
    derived from the Gaussian covariances of the decoded boxes.

    Args:
        pred (torch.Tensor): Predicted bboxes. Only ``pred[:, :2]`` is
            used here.
        target (torch.Tensor): Corresponding gt bboxes. Only
            ``target[:, :2]`` is used here.
        pred_decode (torch.Tensor): Predicted decode bboxes, converted to
            covariances with ``xy_wh_r_2_xy_sigma``. Although the default
            is None, a tensor is required.
        targets_decode (torch.Tensor): Corresponding gt decode bboxes.
            Although the default is None, a tensor is required.
        fun (str): The function applied to distance: ``'ln'`` gives
            ``-log(KFIoU + eps)``, ``'exp'`` gives ``exp(1 - KFIoU) - 1``,
            anything else gives ``1 - KFIoU``. Defaults to None.
        beta (float): Switch point between the quadratic and the linear
            branch of the Smooth-L1 term. Defaults to 1.0/9.0.
        eps (float): Small constant added to the KFIoU denominator and
            inside the logarithm. Defaults to 1e-6.

    Returns:
        loss (torch.Tensor)
    """
    center_pred = pred[:, :2]
    center_gt = target[:, :2]
    _, cov_pred = xy_wh_r_2_xy_sigma(pred_decode)
    _, cov_gt = xy_wh_r_2_xy_sigma(targets_decode)

    # Smooth-L1 on the center offset, summed over the two coordinates.
    offset = torch.abs(center_pred - center_gt)
    quadratic_branch = 0.5 * offset * offset / beta
    linear_branch = offset - 0.5 * beta
    center_loss = torch.where(offset < beta, quadratic_branch,
                              linear_branch).sum(dim=-1)

    # Size term of each Gaussian: 4 * sqrt(det(covariance)).
    size_pred = 4 * cov_pred.det().sqrt()
    size_gt = 4 * cov_gt.det().sqrt()

    # Gain matrix and the covariance obtained by combining both Gaussians.
    gain = cov_pred.bmm((cov_pred + cov_gt).inverse())
    cov_merged = cov_pred - gain.bmm(cov_pred)
    size_merged = 4 * cov_merged.det().sqrt()
    # sqrt() may yield NaN; such entries count as zero overlap.
    size_merged = torch.where(
        torch.isnan(size_merged), torch.zeros_like(size_merged),
        size_merged)

    kfiou = size_merged / (size_pred + size_gt - size_merged + eps)

    if fun == 'ln':
        overlap_loss = -torch.log(kfiou + eps)
    elif fun == 'exp':
        overlap_loss = torch.exp(1 - kfiou) - 1
    else:
        overlap_loss = 1 - kfiou

    # Never return a negative per-box loss.
    return (center_loss + overlap_loss).clamp(min=0)


@ROTATED_LOSSES.register_module()
class KFLoss(nn.Module):
    """Kalman filter based loss.

    Thin ``nn.Module`` wrapper around :func:`kfiou_loss` that stores the
    distance function, the reduction mode and a global loss weight.

    Args:
        fun (str, optional): The function applied to distance. One of
            'none', 'ln' or 'exp'. Defaults to 'none'.
        reduction (str, optional): The reduction method of the loss. One
            of 'none', 'sum' or 'mean'. Defaults to 'mean'.
        loss_weight (float, optional): The weight of loss.
            Defaults to 1.0.

    Returns:
        loss (torch.Tensor)
    """

    def __init__(self,
                 fun='none',
                 reduction='mean',
                 loss_weight=1.0,
                 **kwargs):
        super(KFLoss, self).__init__()
        assert reduction in ['none', 'sum', 'mean']
        assert fun in ['none', 'ln', 'exp']
        self.fun = fun
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                pred_decode=None,
                targets_decode=None,
                reduction_override=None,
                **kwargs):
        """Forward function.

        Args:
            pred (torch.Tensor): Predicted bboxes.
            target (torch.Tensor): Corresponding gt bboxes.
            weight (torch.Tensor, optional): The weight of loss for each
                prediction. May have the same shape as ``pred`` (it is
                then averaged over the last dimension) or one dimension
                less. Defaults to None.
            avg_factor (int, optional): Average factor that is used to
                average the loss. Defaults to None.
            pred_decode (torch.Tensor): Predicted decode bboxes.
            targets_decode (torch.Tensor): Corresponding gt decode bboxes.
            reduction_override (str, optional): The reduction method used
                to override the original reduction method of the loss.
                Defaults to None.

        Returns:
            loss (torch.Tensor)
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)

        # No positively weighted sample: skip the loss computation and
        # return a weighted sum that stays connected to ``pred``.
        if (weight is not None) and (not torch.any(weight > 0)) and (
                reduction != 'none'):
            if pred.dim() == weight.dim() + 1:
                # Per-box weights: add a trailing axis so they broadcast
                # against ``pred`` instead of raising a shape error.
                weight = weight.unsqueeze(-1)
            return (pred * weight).sum()

        # Collapse per-coordinate weights into one weight per box.
        if weight is not None and weight.dim() > 1:
            assert weight.shape == pred.shape
            weight = weight.mean(-1)

        return kfiou_loss(
            pred,
            target,
            fun=self.fun,
            weight=weight,
            avg_factor=avg_factor,
            pred_decode=pred_decode,
            targets_decode=targets_decode,
            reduction=reduction,
            **kwargs) * self.loss_weight
Read the Docs v: v0.3.4
Versions
latest
stable
1.x
v1.0.0rc0
v0.3.4
v0.3.3
v0.3.2
v0.3.1
v0.3.0
v0.2.0
v0.1.1
v0.1.0
main
dev
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.