Shortcuts

Source code for mmrotate.models.losses.kld_reppoints_loss

# Copyright (c) SJTU. All rights reserved.
import torch
import torch.nn as nn
from mmdet.models.losses.utils import weighted_loss

from mmrotate.core import GaussianMixture, gt2gaussian
from ..builder import ROTATED_LOSSES


def kld_single2single(g1, g2):
    """Closed-form KL divergence between paired Gaussian distributions.

    Args:
        g1: Predicted distribution. An object exposing tensor attributes
            ``mu`` and ``var``; ``mu`` must be 3-D and ``var`` 4-D, both
            with size 1 along dim 1 (that axis is squeezed away here).
        g2: Target distribution, unpacked as ``(mean, covariance)``.

    Returns:
        torch.Tensor: ``0.5 * (d^T S_t^-1 d + tr(S_t^-1 S_p)
        + log(|S_t| / |S_p|)) - 1`` for every pair, laid out as a column
        (shape (N, 1) for a batch of N pairs).
    """
    pred_mean, pred_cov = g1.mu, g1.var
    # Exactly one mixture component is supported on the predicted side.
    assert pred_mean.dim() == 3 and pred_mean.size()[1] == 1
    assert pred_cov.dim() == 4 and pred_cov.size()[1] == 1
    pred_mean = pred_mean.squeeze(1)
    pred_cov = pred_cov.squeeze(1)

    target_mean, target_cov = g2

    # Mean offset as a column vector so it can be used in matrix products.
    offset = (pred_mean - target_mean).unsqueeze(-1)
    target_cov_inv = torch.inverse(target_cov)

    # Squared Mahalanobis distance of the offset under the target covariance.
    mahalanobis = torch.matmul(
        torch.matmul(offset.transpose(-1, -2), target_cov_inv),
        offset).squeeze(-1)

    # Trace of S_t^-1 S_p, kept as a column to line up with the other terms.
    trace = torch.diagonal(
        torch.matmul(target_cov_inv, pred_cov), dim1=-2, dim2=-1).sum(
            dim=-1, keepdim=True)
    log_det_ratio = torch.log(
        torch.det(target_cov) / torch.det(pred_cov)).reshape(-1, 1)
    trace_and_log_det = trace + log_det_ratio

    return 0.5 * (mahalanobis + trace_and_log_det) - 1


@weighted_loss
def kld_loss(pred, target, eps=1e-6):
    """KLD-based loss between predicted point sets and target polygons.

    Each predicted point set is fitted with a single-component Gaussian
    mixture, each target polygon is converted with ``gt2gaussian``, and the
    divergence between the two is mapped to ``1 - 1 / (2 + sqrt(kld))``.
    Weighting and reduction are left to the ``weighted_loss`` decorator.

    Args:
        pred (torch.Tensor): Convexes, reshaped here to (N, 9, 2).
        target (torch.Tensor): Polygons, reshaped here to (N, 4, 2).
        eps (float): Lower bound applied to the divergence before the
            square root. Defaults to 1e-6.

    Returns:
        torch.Tensor: Kullback-Leibler Divergence loss.
    """
    point_sets = pred.reshape(-1, 9, 2)
    polygons = target.reshape(-1, 4, 2)

    # One polygon per point set is required, and the batch must be non-empty.
    assert point_sets.shape[0] == polygons.shape[0] and polygons.numel() > 0

    mixture = GaussianMixture(n_components=1, requires_grad=True)
    mixture.fit(point_sets)
    target_gaussian = gt2gaussian(polygons)

    divergence = kld_single2single(mixture, target_gaussian)
    # Clamp first so the square root never sees a value below eps.
    bounded = torch.clamp(divergence, min=eps)

    return 1 - 1 / (2 + bounded.sqrt())


@ROTATED_LOSSES.register_module()
class KLDRepPointsLoss(nn.Module):
    """Kullback-Leibler Divergence loss for RepPoints.

    Thin module wrapper around ``kld_loss`` that stores the loss
    configuration and scales the result by ``loss_weight``.

    Args:
        eps (float): Forwarded to ``kld_loss`` as ``eps``. Defaults to 1e-6.
        reduction (str, optional): The reduction method of the loss.
            Defaults to 'mean'.
        loss_weight (float, optional): The weight of loss. Defaults to 1.0.
    """

    def __init__(self, eps=1e-6, reduction='mean', loss_weight=1.0):
        super(KLDRepPointsLoss, self).__init__()
        self.eps = eps
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None,
                **kwargs):
        """Forward function.

        Args:
            pred (torch.Tensor): Predicted convexes.
            target (torch.Tensor): Corresponding gt convexes.
            weight (torch.Tensor, optional): The weight of loss for each
                prediction. Defaults to None.
            avg_factor (int, optional): Average factor that is used to
                average the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used to
                override the original reduction method of the loss.
                Must be one of None, 'none', 'mean' or 'sum'.
                Defaults to None.

        Returns:
            loss (torch.Tensor)
        """
        # No positively weighted prediction: skip the Gaussian fitting and
        # return the weighted sum of pred instead (0 when every weight is
        # zero). Building it from pred keeps the result tied to pred.
        if weight is not None and not torch.any(weight > 0):
            return (pred * weight.unsqueeze(-1)).sum()  # 0

        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)

        loss_bbox = self.loss_weight * kld_loss(
            pred,
            target,
            weight,
            eps=self.eps,
            reduction=reduction,
            avg_factor=avg_factor,
            **kwargs)
        return loss_bbox
Read the Docs v: v0.2.0
Versions
latest
stable
v0.2.0
v0.1.1
v0.1.0
main
dev
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.