Source code for mmrotate.models.utils.ripool
# Copyright (c) OpenMMLab. All rights reserved.
from torch import nn
class RotationInvariantPooling(nn.Module):
    """Rotation-invariant pooling module.

    Groups the channel dimension of a 4-D input into blocks of
    ``nOrientation`` consecutive channels and keeps only the maximum
    response within each block, so the output has ``nOrientation`` times
    fewer channels than the input.

    Args:
        nInputPlane (int): The number of input planes. Stored on the module
            for reference; it is not used by :meth:`forward`.
        nOrientation (int, optional): The number of oriented channels that
            are pooled together. Defaults to 8.
    """

    def __init__(self, nInputPlane, nOrientation=8):
        super().__init__()
        self.nInputPlane = nInputPlane
        self.nOrientation = nOrientation

    def forward(self, x):
        """Forward function.

        Args:
            x (torch.Tensor): Input of shape ``(N, C, H, W)``. ``C`` must be
                divisible by ``nOrientation``, otherwise the reshape fails.

        Returns:
            torch.Tensor: Tensor of shape ``(N, C // nOrientation, H, W)``
            holding the per-group maximum over the orientation channels.
        """
        # Unpacking into four names also enforces that the input is 4-D.
        batch, _, height, width = x.size()
        # Split channels into (groups, orientations); consecutive channels
        # belong to the same group.
        grouped = x.view(batch, -1, self.nOrientation, height, width)
        # Tensor.max(dim=...) returns (values, indices); only the values
        # are needed.
        pooled, _ = grouped.max(dim=2, keepdim=False)
        return pooled