
Source code for mmrotate.models.backbones.re_resnet

# Copyright (c) OpenMMLab. All rights reserved.
# Modified from csuhan: https://github.com/csuhan/ReDet
import warnings

import e2cnn.nn as enn
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.runner import BaseModule
from torch.nn.modules.batchnorm import _BatchNorm

from ..builder import ROTATED_BACKBONES
from ..utils import (build_enn_divide_feature, build_enn_norm_layer,
                     build_enn_trivial_feature, ennAvgPool, ennConv,
                     ennMaxPool, ennReLU, ennTrivialConv)


class BasicBlock(enn.EquivariantModule):
    """BasicBlock for ReResNet.

    Args:
        in_channels (int): Input channels of this block.
        out_channels (int): Output channels of this block.
        expansion (int): The ratio of ``out_channels/mid_channels`` where
            ``mid_channels`` is the output channels of conv1. This is a
            reserved argument in BasicBlock and should always be 1. Default: 1.
        stride (int): stride of the block. Default: 1
        dilation (int): dilation of convolution. Default: 1
        downsample (nn.Module): downsample operation on identity branch.
            Default: None.
        style (str): ``"pytorch"`` or ``"caffe"``. Unused here; reserved for
            a unified API with Bottleneck.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed.
        conv_cfg (dict): dictionary to construct and config conv layer.
            Default: None
        norm_cfg (dict): dictionary to construct and config norm layer.
            Default: dict(type='BN')
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 expansion=1,
                 stride=1,
                 dilation=1,
                 downsample=None,
                 style='pytorch',
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 init_cfg=None):
        super(BasicBlock, self).__init__()
        self.in_type = build_enn_divide_feature(in_channels)
        self.out_type = build_enn_divide_feature(out_channels)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.expansion = expansion
        assert self.expansion == 1
        assert out_channels % expansion == 0
        self.mid_channels = out_channels // expansion
        self.stride = stride
        self.dilation = dilation
        self.style = style
        self.with_cp = with_cp
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg

        self.norm1_name, norm1 = build_enn_norm_layer(
            self.mid_channels, postfix=1)
        self.norm2_name, norm2 = build_enn_norm_layer(out_channels, postfix=2)

        self.conv1 = ennConv(
            in_channels,
            self.mid_channels,
            3,
            stride=stride,
            padding=dilation,
            dilation=dilation,
            bias=False)
        self.add_module(self.norm1_name, norm1)
        self.relu1 = ennReLU(self.mid_channels)
        self.conv2 = ennConv(
            self.mid_channels, out_channels, 3, padding=1, bias=False)
        self.add_module(self.norm2_name, norm2)

        self.relu2 = ennReLU(out_channels)
        self.downsample = downsample

    @property
    def norm1(self):
        """Get normalizion layer's name."""
        return getattr(self, self.norm1_name)

    @property
    def norm2(self):
        """Get normalizion layer's name."""
        return getattr(self, self.norm2_name)

    def forward(self, x):
        """Forward function of BasicBlock."""

        def _inner_forward(x):
            identity = x

            out = self.conv1(x)
            out = self.norm1(out)
            out = self.relu1(out)

            out = self.conv2(out)
            out = self.norm2(out)

            if self.downsample is not None:
                identity = self.downsample(x)

            out += identity

            return out

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        out = self.relu2(out)

        return out

    def evaluate_output_shape(self, input_shape):
        """Evaluate output shape."""
        assert len(input_shape) == 4
        assert input_shape[1] == self.in_type.size
        if self.downsample is not None:
            return self.downsample.evaluate_output_shape(input_shape)
        else:
            return input_shape
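
# A minimal usage sketch (added for illustration, not part of the original
# source; assumes the default 8-orientation group behind
# build_enn_divide_feature, so channel counts must be divisible by 8):
#
#     >>> import torch
#     >>> block = BasicBlock(in_channels=64, out_channels=64)
#     >>> x = enn.GeometricTensor(torch.rand(1, 64, 56, 56), block.in_type)
#     >>> out = block(x)
#     >>> out.tensor.shape  # stride=1 preserves the spatial size
#     torch.Size([1, 64, 56, 56])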


class Bottleneck(enn.EquivariantModule):
    """Bottleneck block for ReResNet.

    Args:
        in_channels (int): Input channels of this block.
        out_channels (int): Output channels of this block.
        expansion (int): The ratio of ``out_channels/mid_channels`` where
            ``mid_channels`` is the input/output channels of conv2. Default: 4.
        stride (int): stride of the block. Default: 1
        dilation (int): dilation of convolution. Default: 1
        downsample (nn.Module): downsample operation on identity branch.
            Default: None.
        style (str): ``"pytorch"`` or ``"caffe"``. If set to "pytorch", the
            stride-two layer is the 3x3 conv layer, otherwise the stride-two
            layer is the first 1x1 conv layer. Default: "pytorch".
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed.
        conv_cfg (dict): dictionary to construct and config conv layer.
            Default: None
        norm_cfg (dict): dictionary to construct and config norm layer.
            Default: dict(type='BN')
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 expansion=4,
                 stride=1,
                 dilation=1,
                 downsample=None,
                 style='pytorch',
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 init_cfg=None):
        super(Bottleneck, self).__init__()
        assert style in ['pytorch', 'caffe']
        self.in_type = build_enn_divide_feature(in_channels)
        self.out_type = build_enn_divide_feature(out_channels)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.expansion = expansion
        assert out_channels % expansion == 0
        self.mid_channels = out_channels // expansion
        self.stride = stride
        self.dilation = dilation
        self.style = style
        self.with_cp = with_cp
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        if self.style == 'pytorch':
            self.conv1_stride = 1
            self.conv2_stride = stride
        else:
            self.conv1_stride = stride
            self.conv2_stride = 1

        self.norm1_name, norm1 = build_enn_norm_layer(
            self.mid_channels, postfix=1)
        self.norm2_name, norm2 = build_enn_norm_layer(
            self.mid_channels, postfix=2)
        self.norm3_name, norm3 = build_enn_norm_layer(out_channels, postfix=3)

        self.conv1 = ennConv(
            in_channels,
            self.mid_channels,
            kernel_size=1,
            stride=self.conv1_stride,
            bias=False)
        self.add_module(self.norm1_name, norm1)
        self.relu1 = ennReLU(self.mid_channels)
        self.conv2 = ennConv(
            self.mid_channels,
            self.mid_channels,
            kernel_size=3,
            stride=self.conv2_stride,
            padding=dilation,
            dilation=dilation,
            bias=False)

        self.add_module(self.norm2_name, norm2)
        self.relu2 = ennReLU(self.mid_channels)
        self.conv3 = ennConv(
            self.mid_channels, out_channels, kernel_size=1, bias=False)
        self.add_module(self.norm3_name, norm3)
        self.relu3 = ennReLU(out_channels)

        self.downsample = downsample

    @property
    def norm1(self):
        """Get normalizion layer's name."""
        return getattr(self, self.norm1_name)

    @property
    def norm2(self):
        """Get normalizion layer's name."""
        return getattr(self, self.norm2_name)

    @property
    def norm3(self):
        """Get normalizion layer's name."""
        return getattr(self, self.norm3_name)

    def forward(self, x):
        """Forward function of Bottleneck."""

        def _inner_forward(x):
            identity = x

            out = self.conv1(x)
            out = self.norm1(out)
            out = self.relu1(out)

            out = self.conv2(out)
            out = self.norm2(out)
            out = self.relu2(out)

            out = self.conv3(out)
            out = self.norm3(out)

            if self.downsample is not None:
                identity = self.downsample(x)

            out += identity

            return out

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        out = self.relu3(out)

        return out

    def evaluate_output_shape(self, input_shape):
        """Evaluate output shape."""
        assert len(input_shape) == 4
        assert input_shape[1] == self.in_type.size
        if self.downsample is not None:
            return self.downsample.evaluate_output_shape(input_shape)
        else:
            return input_shape
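
# A minimal usage sketch (added for illustration, not part of the original
# source): a stride-2 Bottleneck needs a matching downsample branch on the
# identity path, built here the same way ResLayer does it below.
#
#     >>> import torch
#     >>> downsample = enn.SequentialModule(
#     ...     ennConv(64, 256, kernel_size=1, stride=2, bias=False),
#     ...     build_enn_norm_layer(256)[1])
#     >>> block = Bottleneck(64, 256, expansion=4, stride=2,
#     ...                    downsample=downsample)
#     >>> x = enn.GeometricTensor(torch.rand(1, 64, 56, 56), block.in_type)
#     >>> out = block(x)
#     >>> out.tensor.shape  # channels x4, spatial size halved
#     torch.Size([1, 256, 28, 28])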


def get_expansion(block, expansion=None):
    """Get the expansion of a residual block.

    The block expansion will be obtained by the following order:

    1. If ``expansion`` is given, just return it.
    2. If ``block`` has the attribute ``expansion``, then return
       ``block.expansion``.
    3. Return the default value according to the block type:
       1 for ``BasicBlock`` and 4 for ``Bottleneck``.

    Args:
        block (class): The block class.
        expansion (int | None): The given expansion ratio.

    Returns:
        int: The expansion of the block.
    """
    if isinstance(expansion, int):
        assert expansion > 0
    elif expansion is None:
        if hasattr(block, 'expansion'):
            expansion = block.expansion
        elif issubclass(block, BasicBlock):
            expansion = 1
        elif issubclass(block, Bottleneck):
            expansion = 4
        else:
            raise TypeError(f'expansion is not specified for {block.__name__}')
    else:
        raise TypeError('expansion must be an integer or None')

    return expansion
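
# The resolution order in practice (a sketch added for illustration, not
# part of the original source):
#
#     >>> get_expansion(BasicBlock)     # default for BasicBlock
#     1
#     >>> get_expansion(Bottleneck)     # default for Bottleneck
#     4
#     >>> get_expansion(Bottleneck, 2)  # an explicit value always wins
#     2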


class ResLayer(nn.Sequential):
    """ResLayer to build ReResNet style backbone.

    Args:
        block (nn.Module): Residual block used to build ResLayer.
        num_blocks (int): Number of blocks.
        in_channels (int): Input channels of this block.
        out_channels (int): Output channels of this block.
        expansion (int, optional): The expansion for BasicBlock/Bottleneck.
            If not specified, it will firstly be obtained via
            ``block.expansion``. If the block has no attribute "expansion",
            the following default values will be used: 1 for BasicBlock and
            4 for Bottleneck. Default: None.
        stride (int): stride of the first block. Default: 1.
        avg_down (bool): Use AvgPool instead of stride conv when
            downsampling in the bottleneck. Default: False
        conv_cfg (dict): dictionary to construct and config conv layer.
            Default: None
        norm_cfg (dict): dictionary to construct and config norm layer.
            Default: dict(type='BN')
    """

    def __init__(self,
                 block,
                 num_blocks,
                 in_channels,
                 out_channels,
                 expansion=None,
                 stride=1,
                 avg_down=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 **kwargs):
        self.block = block
        self.expansion = get_expansion(block, expansion)

        downsample = None
        if stride != 1 or in_channels != out_channels:
            downsample = []
            conv_stride = stride
            if avg_down and stride != 1:
                conv_stride = 1
                downsample.append(
                    ennAvgPool(
                        in_channels,
                        kernel_size=stride,
                        stride=stride,
                        ceil_mode=True))
            downsample.extend([
                ennConv(
                    in_channels,
                    out_channels,
                    kernel_size=1,
                    stride=conv_stride,
                    bias=False),
                build_enn_norm_layer(out_channels)[1]
            ])
            downsample = enn.SequentialModule(*downsample)

        layers = []
        layers.append(
            block(
                in_channels=in_channels,
                out_channels=out_channels,
                expansion=self.expansion,
                stride=stride,
                downsample=downsample,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                **kwargs))
        in_channels = out_channels
        for _ in range(1, num_blocks):
            layers.append(
                block(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    expansion=self.expansion,
                    stride=1,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    **kwargs))
        super(ResLayer, self).__init__(*layers)
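
# A minimal usage sketch (added for illustration, not part of the original
# source): only the first block carries the stride and the downsample
# branch; the remaining blocks run at stride 1 with matched channels.
#
#     >>> layer = ResLayer(Bottleneck, num_blocks=3,
#     ...                  in_channels=64, out_channels=256, stride=2)
#     >>> len(layer)  # three Bottleneck blocks
#     3
#     >>> layer[0].downsample is not None
#     True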


@ROTATED_BACKBONES.register_module()
class ReResNet(BaseModule):
    """ReResNet backbone.

    Please refer to the `paper <https://arxiv.org/abs/1512.03385>`_ for
    details.

    Args:
        depth (int): Network depth, from {18, 34, 50, 101, 152}.
        in_channels (int): Number of input image channels. Default: 3.
        stem_channels (int): Output channels of the stem layer. Default: 64.
        base_channels (int): Middle channels of the first stage. Default: 64.
        num_stages (int): Stages of the network. Default: 4.
        strides (Sequence[int]): Strides of the first block of each stage.
            Default: ``(1, 2, 2, 2)``.
        dilations (Sequence[int]): Dilation of each stage.
            Default: ``(1, 1, 1, 1)``.
        out_indices (Sequence[int]): Output from which stages. If only one
            stage is specified, a single tensor (feature map) is returned;
            if multiple stages are specified, a tuple of tensors will be
            returned. Default: ``(3, )``.
        style (str): ``"pytorch"`` or ``"caffe"``. If set to "pytorch", the
            stride-two layer is the 3x3 conv layer, otherwise the stride-two
            layer is the first 1x1 conv layer.
        deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 convs.
            Default: False.
        avg_down (bool): Use AvgPool instead of stride conv when
            downsampling in the bottleneck. Default: False.
        frozen_stages (int): Stages to be frozen (stop grad and set eval
            mode). -1 means not freezing any parameters. Default: -1.
        conv_cfg (dict | None): The config dict for conv layers.
            Default: None.
        norm_cfg (dict): The config dict for norm layers.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save
            some memory while slowing down the training speed.
            Default: False.
        zero_init_residual (bool): Whether to use zero init for the last
            norm layer in resblocks to let them behave as identity.
            Default: True.
    """

    arch_settings = {
        18: (BasicBlock, (2, 2, 2, 2)),
        34: (BasicBlock, (3, 4, 6, 3)),
        50: (Bottleneck, (3, 4, 6, 3)),
        101: (Bottleneck, (3, 4, 23, 3)),
        152: (Bottleneck, (3, 8, 36, 3))
    }

    def __init__(self,
                 depth,
                 in_channels=3,
                 stem_channels=64,
                 base_channels=64,
                 expansion=None,
                 num_stages=4,
                 strides=(1, 2, 2, 2),
                 dilations=(1, 1, 1, 1),
                 out_indices=(3, ),
                 style='pytorch',
                 deep_stem=False,
                 avg_down=False,
                 frozen_stages=-1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', requires_grad=True),
                 norm_eval=False,
                 with_cp=False,
                 zero_init_residual=True,
                 pretrained=None,
                 init_cfg=None):
        super(ReResNet, self).__init__()
        self.in_type = build_enn_trivial_feature(3)
        assert not (init_cfg and pretrained), \
            'init_cfg and pretrained cannot be specified at the same time'
        if isinstance(pretrained, str):
            warnings.warn('DeprecationWarning: pretrained is deprecated, '
                          'please use "init_cfg" instead')
            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
        elif pretrained is None:
            if init_cfg is None:
                self.init_cfg = [
                    dict(type='Kaiming', layer='Conv2d'),
                    dict(
                        type='Constant',
                        val=1,
                        layer=['_BatchNorm', 'GroupNorm'])
                ]
        else:
            raise TypeError('pretrained must be a str or None')
        if depth not in self.arch_settings:
            raise KeyError(f'invalid depth {depth} for resnet')
        self.depth = depth
        self.stem_channels = stem_channels
        self.base_channels = base_channels
        self.num_stages = num_stages
        assert num_stages >= 1 and num_stages <= 4
        self.strides = strides
        self.dilations = dilations
        assert len(strides) == len(dilations) == num_stages
        self.out_indices = out_indices
        assert max(out_indices) < num_stages
        self.style = style
        self.deep_stem = deep_stem
        self.avg_down = avg_down
        self.frozen_stages = frozen_stages
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.with_cp = with_cp
        self.norm_eval = norm_eval
        self.zero_init_residual = zero_init_residual
        self.block, stage_blocks = self.arch_settings[depth]
        self.stage_blocks = stage_blocks[:num_stages]
        self.expansion = get_expansion(self.block, expansion)

        self._make_stem_layer(in_channels, stem_channels)

        self.res_layers = []
        _in_channels = stem_channels
        _out_channels = base_channels * self.expansion
        for i, num_blocks in enumerate(self.stage_blocks):
            stride = strides[i]
            dilation = dilations[i]
            res_layer = self.make_res_layer(
                block=self.block,
                num_blocks=num_blocks,
                in_channels=_in_channels,
                out_channels=_out_channels,
                expansion=self.expansion,
                stride=stride,
                dilation=dilation,
                style=self.style,
                avg_down=self.avg_down,
                with_cp=with_cp,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg)
            _in_channels = _out_channels
            _out_channels *= 2
            layer_name = f'layer{i + 1}'
            self.add_module(layer_name, res_layer)
            self.res_layers.append(layer_name)

        self._freeze_stages()

        self.feat_dim = res_layer[-1].out_channels

    def make_res_layer(self, **kwargs):
        """Build a ResLayer."""
        return ResLayer(**kwargs)

    @property
    def norm1(self):
        """Get the normalization layer."""
        return getattr(self, self.norm1_name)

    def _make_stem_layer(self, in_channels, stem_channels):
        """Build stem layer."""
        if not self.deep_stem:
            self.conv1 = ennTrivialConv(
                in_channels,
                stem_channels,
                kernel_size=7,
                stride=2,
                padding=3)
            self.norm1_name, norm1 = build_enn_norm_layer(
                stem_channels, postfix=1)
            self.add_module(self.norm1_name, norm1)
            self.relu = ennReLU(stem_channels)
        self.maxpool = ennMaxPool(
            stem_channels, kernel_size=3, stride=2, padding=1)

    def _freeze_stages(self):
        """Freeze stages."""
        if self.frozen_stages >= 0:
            if not self.deep_stem:
                self.norm1.eval()
                for m in [self.conv1, self.norm1]:
                    for param in m.parameters():
                        param.requires_grad = False

        for i in range(1, self.frozen_stages + 1):
            m = getattr(self, f'layer{i}')
            m.eval()
            for param in m.parameters():
                param.requires_grad = False

    def forward(self, x):
        """Forward function of ReResNet."""
        if not self.deep_stem:
            x = enn.GeometricTensor(x, self.in_type)
            x = self.conv1(x)
            x = self.norm1(x)
            x = self.relu(x)
        x = self.maxpool(x)
        outs = []
        for i, layer_name in enumerate(self.res_layers):
            res_layer = getattr(self, layer_name)
            x = res_layer(x)
            if i in self.out_indices:
                outs.append(x)
        if len(outs) == 1:
            return outs[0]
        else:
            return tuple(outs)

    def train(self, mode=True):
        """Train function of ReResNet."""
        super(ReResNet, self).train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                # trick: eval has effect on BatchNorm only
                if isinstance(m, _BatchNorm):
                    m.eval()
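
# A minimal end-to-end sketch (added for illustration, not part of the
# original source; weights here are randomly initialized, and the expected
# shapes assume the standard ResNet-50 stage widths):
#
#     >>> import torch
#     >>> model = ReResNet(depth=50, out_indices=(0, 1, 2, 3))
#     >>> model.eval()
#     >>> inputs = torch.rand(1, 3, 224, 224)
#     >>> outputs = model(inputs)
#     >>> for out in outputs:
#     ...     print(out.tensor.shape)
#     torch.Size([1, 256, 56, 56])
#     torch.Size([1, 512, 28, 28])
#     torch.Size([1, 1024, 14, 14])
#     torch.Size([1, 2048, 7, 7])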