Source code for mmrotate.models.losses.kld_reppoints_loss
# Copyright (c) SJTU. All rights reserved.
import torch
import torch.nn as nn
from mmdet.models.losses.utils import weighted_loss
from mmrotate.core import GaussianMixture, gt2gaussian
from ..builder import ROTATED_LOSSES
def kld_single2single(g1, g2):
    """Evaluate the Kullback-Leibler divergence KL(g1 || g2) per sample.

    Args:
        g1: Object exposing tensor attributes ``mu`` and ``var``. ``mu``
            must be 3-D and ``var`` 4-D, each with size 1 along dim 1;
            that singleton dim is squeezed away before use.
        g2 (tuple[torch.Tensor, torch.Tensor]): ``(mean, covariance)``
            pair describing the second Gaussian. It is unpacked directly
            and must be shape-compatible with the squeezed ``g1`` tensors.

    Returns:
        torch.Tensor: Divergence values laid out as a column, one row per
            Gaussian pair.
    """
    pred_mean, pred_cov = g1.mu, g1.var
    assert pred_mean.dim() == 3 and pred_mean.size(1) == 1
    assert pred_cov.dim() == 4 and pred_cov.size(1) == 1
    # Drop the singleton component axis.
    pred_mean = pred_mean.squeeze(1)
    pred_cov = pred_cov.squeeze(1)

    gt_mean, gt_cov = g2
    gt_cov_inv = torch.inverse(gt_cov)

    # Mahalanobis term: (mu_p - mu_t)^T * inv(Sigma_t) * (mu_p - mu_t).
    offset = (pred_mean - gt_mean).unsqueeze(-1)
    mahalanobis = (offset.transpose(-1, -2) @ gt_cov_inv @ offset).squeeze(-1)

    # Trace term: tr(inv(Sigma_t) * Sigma_p), kept as a column.
    trace_term = torch.diagonal(
        gt_cov_inv @ pred_cov, dim1=-2, dim2=-1).sum(dim=-1, keepdim=True)

    # Log-determinant ratio: log(|Sigma_t| / |Sigma_p|), as a column.
    log_det_ratio = torch.log(
        torch.det(gt_cov) / torch.det(pred_cov)).reshape(-1, 1)

    # The subtracted constant 1 is d / 2 of the closed-form Gaussian KL
    # with d = 2 (two-dimensional distributions).
    return 0.5 * (mahalanobis + (trace_term + log_det_ratio)) - 1
@weighted_loss
def kld_loss(pred, target, eps=1e-6):
    """Kullback-Leibler Divergence loss between point sets and polygons.

    A single-component Gaussian is fitted to each predicted point set, the
    target polygon is converted to a Gaussian, and the KL divergence
    between the two is mapped to ``1 - 1 / (2 + sqrt(kld))``.

    Args:
        pred (torch.Tensor): Predicted convexes; reshaped to (N, 9, 2).
        target (torch.Tensor): Ground-truth polygons; reshaped to
            (N, 4, 2). Must be non-empty and match ``pred`` in N.
        eps (float): Lower clamp applied to the divergence before the
            square root. Defaults to 1e-6.

    Returns:
        torch.Tensor: Per-sample loss before the weighting / reduction
            performed by the ``weighted_loss`` decorator.
    """
    point_sets = pred.reshape(-1, 9, 2)
    polygons = target.reshape(-1, 4, 2)
    assert point_sets.size(0) == polygons.size(0) and polygons.numel() > 0

    # Fit one Gaussian per predicted point set.
    mixture = GaussianMixture(n_components=1, requires_grad=True)
    mixture.fit(point_sets)

    divergence = kld_single2single(mixture, gt2gaussian(polygons))
    # Clamp from below so the square root stays away from zero.
    bounded = torch.clamp(divergence, min=eps)

    # NOTE(review): kld_single2single returns a column (N, 1); a 1-D weight
    # of shape (N,) would presumably broadcast to (N, N) inside the
    # weighted_loss wrapper -- confirm the weight shape used by callers.
    return 1 - 1 / (2 + bounded.sqrt())
@ROTATED_LOSSES.register_module()
class KLDRepPointsLoss(nn.Module):
    """Kullback-Leibler Divergence loss for RepPoints.

    Thin ``nn.Module`` wrapper that stores the loss configuration and
    forwards to ``kld_loss``.

    Args:
        eps (float): Lower clamp forwarded to ``kld_loss``.
            Defaults to 1e-6.
        reduction (str, optional): The reduction method of the
            loss. Defaults to 'mean'.
        loss_weight (float, optional): The weight of loss. Defaults to 1.0.
    """

    def __init__(self, eps=1e-6, reduction='mean', loss_weight=1.0):
        super(KLDRepPointsLoss, self).__init__()
        self.eps = eps
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None,
                **kwargs):
        """Forward function.

        Args:
            pred (torch.Tensor): Predicted convexes.
            target (torch.Tensor): Corresponding gt convexes.
            weight (torch.Tensor, optional): The weight of loss for each
                prediction. Defaults to None.
            avg_factor (int, optional): Average factor that is used to
                average the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used
                to override the original reduction method of the loss.
                Must be one of None, 'none', 'mean' or 'sum'.
                Defaults to None.
            **kwargs: Extra keyword arguments passed through to
                ``kld_loss``.

        Returns:
            loss (torch.Tensor)
        """
        if weight is not None and not torch.any(weight > 0):
            # No positive weight: skip the KLD computation and return the
            # weighted sum of ``pred`` (0 when every weight is zero). The
            # result is still derived from ``pred``, presumably so the
            # autograd graph stays connected.
            return (pred * weight.unsqueeze(-1)).sum()  # 0

        assert reduction_override in (None, 'none', 'mean', 'sum')
        # A falsy override (None) falls back to the configured reduction.
        reduction = (
            reduction_override if reduction_override else self.reduction)

        loss_bbox = self.loss_weight * kld_loss(
            pred,
            target,
            weight,
            eps=self.eps,
            reduction=reduction,
            avg_factor=avg_factor,
            **kwargs)
        return loss_bbox