Shortcuts

教程 3: 自定义模型

我们大致将模型组件分为了 5 种类型。

  • 主干网络 (Backbone): 通常是一个全卷积网络 (FCN),用来提取特征图,比如残差网络 (ResNet)。也可以是基于视觉 Transformer 的网络,比如 Swin Transformer 等。

  • Neck: 主干网络和任务头 (Head) 之间的连接组件,比如 FPN, ReFPN。

  • 任务头 (Head): 用于某种具体任务(比如边界框预测)的组件。

  • 区域特征提取器 (Roi Extractor): 用于从特征图上提取区域特征的组件,比如 RoI Align Rotated。

  • 损失 (loss): 任务头上用于计算损失函数的组件,比如 FocalLoss, GWDLoss, and KFIoULoss。

开发新的组件

添加新的主干网络

这里,我们以 MobileNet 为例来展示如何开发新组件。

1. 定义一个新的主干网络(以 MobileNet 为例)

新建文件 mmrotate/models/backbones/mobilenet.py

import torch.nn as nn

from mmrotate.models.builder import ROTATED_BACKBONES


@ROTATED_BACKBONES.register_module()
class MobileNet(nn.Module):

    def __init__(self, arg1, arg2):
        pass

    def forward(self, x):  # should return a tuple
        pass

2. 导入模块

你可以将下面的代码添加到 mmrotate/models/backbones/__init__.py 中:

from .mobilenet import MobileNet

或者添加如下代码

custom_imports = dict(
    imports=['mmrotate.models.backbones.mobilenet'],
    allow_failed_imports=False)

到配置文件中以避免修改原始代码。

3. 在你的配置文件中使用该主干网络

model = dict(
    ...
    backbone=dict(
        type='MobileNet',
        arg1=xxx,
        arg2=xxx),
    ...

添加新的 Neck

1. 定义一个 Neck(以 PAFPN 为例)

新建文件 mmrotate/models/necks/pafpn.py

from mmrotate.models.builder import ROTATED_NECKS

@ROTATED_NECKS.register_module()
class PAFPN(nn.Module):

    def __init__(self,
                in_channels,
                out_channels,
                num_outs,
                start_level=0,
                end_level=-1,
                add_extra_convs=False):
        pass

    def forward(self, inputs):
        # implementation is ignored
        pass

2. 导入该模块

你可以添加下述代码到 mmrotate/models/necks/__init__.py

from .pafpn import PAFPN

或者添加

custom_imports = dict(
    imports=['mmrotate.models.necks.pafpn.py'],
    allow_failed_imports=False)

到配置文件中以避免修改原始代码。

3. 修改配置文件

neck=dict(
    type='PAFPN',
    in_channels=[256, 512, 1024, 2048],
    out_channels=256,
    num_outs=5)

添加新的损失

假设你想添加一个新的损失 MyLoss 用于边界框回归。 为了添加一个新的损失函数,用户需要在 mmrotate/models/losses/my_loss.py 中实现。 装饰器 weighted_loss 可以使损失每个部分加权。

import torch
import torch.nn as nn

from mmrotate.models.builder import ROTATED_LOSSES
from mmdet.models.losses.utils import weighted_loss

@weighted_loss
def my_loss(pred, target):
    assert pred.size() == target.size() and target.numel() > 0
    loss = torch.abs(pred - target)
    return loss

@ROTATED_LOSSES.register_module()
class MyLoss(nn.Module):

    def __init__(self, reduction='mean', loss_weight=1.0):
        super(MyLoss, self).__init__()
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        loss_bbox = self.loss_weight * my_loss(
            pred, target, weight, reduction=reduction, avg_factor=avg_factor)
        return loss_bbox

然后,用户需要把下面的代码加到 mmrotate/models/losses/__init__.py 中。

from .my_loss import MyLoss, my_loss

或者,你可以添加:

custom_imports=dict(
    imports=['mmrotate.models.losses.my_loss'])

到配置文件来实现相同的目的。

因为 MyLoss 是用于回归的,你需要在 Head 中修改 loss_bbox 字段:

loss_bbox=dict(type='MyLoss', loss_weight=1.0))
Read the Docs v: latest
Versions
latest
stable
1.x
v1.0.0rc0
v0.3.4
v0.3.3
v0.3.2
v0.3.1
v0.3.0
v0.2.0
v0.1.1
v0.1.0
main
dev
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.