Shortcuts

教程 4: 自定义训练设置

自定义优化设置

自定义 Pytorch 支持的优化器

我们已经支持了全部 Pytorch 自带的优化器,唯一需要修改的就是配置文件中 optimizer 部分。 例如,如果您想使用 ADAM (注意如下操作可能会让模型表现下降),可以使用如下修改:

optimizer = dict(type='Adam', lr=0.0003, weight_decay=0.0001)

为了修改模型训练的学习率,使用者仅需修改配置文件里 optimizerlr 即可。 使用者可以参考 PyTorch 的 API doc 直接设置参数。

自定义用户自己实现的优化器

1. 定义一个新的优化器

一个自定义的优化器可以这样定义:

假如您想增加一个叫做 MyOptimizer 的优化器,它的参数分别有 a, b, 和 c。 您需要创建一个名为 mmrotate/core/optimizer 的新文件夹;然后参考如下代码段在 mmrotate/core/optimizer/my_optimizer.py 文件中实现新的优化器:

from mmdet.core.optimizer.registry import OPTIMIZERS
from torch.optim import Optimizer


@OPTIMIZERS.register_module()
class MyOptimizer(Optimizer):

    def __init__(self, a, b, c)

2. 增加优化器到注册表 (registry)

为了能够使得上述添加的模块被 mmrotate 发现,需要先将该模块添加到主命名空间(main namespace)。

  • 修改 mmrotate/core/optimizer/__init__.py 文件来导入该模块。

    新的被定义的模块应该被导入到 mmrotate/core/optimizer/__init__.py 中,这样注册表才会发现新的模块并添加它:

from .my_optimizer import MyOptimizer
  • 在配置文件中使用 custom_imports 来手动添加该模块

custom_imports = dict(imports=['mmrotate.core.optimizer.my_optimizer'], allow_failed_imports=False)

mmrotate.core.optimizer.my_optimizer 模块将会在程序开始被导入,并且 MyOptimizer 类将会自动注册。 需要注意只有包含 MyOptimizer 类的包 (package) 应当被导入。 而 mmrotate.core.optimizer.my_optimizer.MyOptimizer 不能 被直接导入。

事实上,在这种导入方式下用户可以用完全不同的文件夹结构,只要这一模块的根目录已经被添加到 PYTHONPATH 里面。

3. 在配置文件中指定优化器

之后您可以在配置文件的 optimizer 部分里面使用 MyOptimizer。 在配置文件里,优化器按照如下形式被定义在 optimizer 部分里:

optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)

要使用用户自定义的优化器,这部分应该改成:

optimizer = dict(type='MyOptimizer', a=a_value, b=b_value, c=c_value)

自定义优化器的构造函数 (constructor)

有些模型的优化器可能有一些特别参数配置,例如批归一化层 (BatchNorm layers) 的权重衰减系数 (weight decay)。 用户可以通过自定义优化器的构造函数去微调这些细粒度参数。

from mmcv.utils import build_from_cfg

from mmcv.runner.optimizer import OPTIMIZER_BUILDERS, OPTIMIZERS
from mmrotate.utils import get_root_logger
from .my_optimizer import MyOptimizer


@OPTIMIZER_BUILDERS.register_module()
class MyOptimizerConstructor(object):

    def __init__(self, optimizer_cfg, paramwise_cfg=None):

    def __call__(self, model):

        return my_optimizer

mmcv 默认的优化器构造函数实现可以参考 这里 ,这也可以作为新的优化器构造函数的模板。

其他配置

优化器未实现的技巧应该通过修改优化器构造函数(如设置基于参数的学习率)或者钩子(hooks)去实现。我们列出一些常见的设置,它们可以稳定或加速模型的训练。 如果您有更多的设置,欢迎在 PR 和 issue 里面提出。

  • 使用梯度裁剪 (gradient clip) 来稳定训练: 一些模型需要梯度裁剪来稳定训练过程。使用方式如下:

    optimizer_config = dict(
        _delete_=True, grad_clip=dict(max_norm=35, norm_type=2))
    

    如果您的配置继承了已经设置了 optimizer_config 的基础配置(base config),你可能需要设置 _delete_=True 来覆盖不必要的配置参数。请参考 配置文档 了解更多细节。

  • 使用动量调度加速模型收敛: 我们支持动量规划器(Momentum scheduler),以实现根据学习率调节模型优化过程中的动量设置,这可以使模型以更快速度收敛。 动量规划器经常与学习率规划器(LR scheduler)一起使用,例如下面的配置经常被用于 3D 检测模型训练中以加速收敛。更多细节请参考 CyclicLrUpdaterCyclicMomentumUpdater

    lr_config = dict(
        policy='cyclic',
        target_ratio=(10, 1e-4),
        cyclic_times=1,
        step_ratio_up=0.4,
    )
    momentum_config = dict(
        policy='cyclic',
        target_ratio=(0.85 / 0.95, 1),
        cyclic_times=1,
        step_ratio_up=0.4,
    )
    

自定义训练计划

默认地,我们使用 1x 计划(1x schedule)的步进学习率(step learning rate),这在 MMCV 中被称为 StepLRHook。 我们支持很多其他的学习率规划器,参考 这里 ,例如 CosineAnnealingPoly。下面是一些例子:

  • Poly :

    lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False)
    
  • ConsineAnnealing :

    lr_config = dict(
        policy='CosineAnnealing',
        warmup='linear',
        warmup_iters=1000,
        warmup_ratio=1.0 / 10,
        min_lr_ratio=1e-5)
    

自定义工作流 (workflow)

工作流是一个专门定义运行顺序和轮数(epochs)的列表。 默认情况下它设置成:

workflow = [('train', 1)]

这是指训练 1 个 epoch。 有时候用户可能想检查一些模型在验证集上的指标,如损失函数值(Loss)和准确性(Accuracy)。 在这种情况下,我们可以将工作流设置为:

[('train', 1), ('val', 1)]

这样以来, 1 个 epoch 训练,1 个 epoch 验证将交替运行。

注意:

  1. 模型参数在验证的阶段不会被自动更新。

  2. 配置文件里的键值 total_epochs 仅控制训练的 epochs 数目,而不会影响验证工作流。

  3. 工作流 [('train', 1), ('val', 1)][('train', 1)] 将不会改变 EvalHook 的行为,因为 EvalHookafter_train_epoch 调用而且验证的工作流仅仅影响通过调用 after_val_epoch 的钩子 (hooks)。因此, [('train', 1), ('val', 1)][('train', 1)] 的区别仅在于 runner 将在每次训练阶段(training epoch)结束后计算在验证集上的损失。

自定义钩 (hooks)

自定义用户自己实现的钩子(hooks)

1. 实现一个新的钩子(hook)

在某些情况下,用户可能需要实现一个新的钩子。 MMRotate 支持训练中的自定义钩子。 因此,用户可以直接在 mmrotate 或其基于 mmdet 的代码库中实现钩子,并通过仅在训练中修改配置来使用钩子。 这里我们举一个例子:在 mmrotate 中创建一个新的钩子并在训练中使用它。

from mmcv.runner import HOOKS, Hook


@HOOKS.register_module()
class MyHook(Hook):

    def __init__(self, a, b):
        pass

    def before_run(self, runner):
        pass

    def after_run(self, runner):
        pass

    def before_epoch(self, runner):
        pass

    def after_epoch(self, runner):
        pass

    def before_iter(self, runner):
        pass

    def after_iter(self, runner):
        pass

用户需要根据钩子的功能指定钩子在训练各阶段中( before_run , after_run , before_epoch , after_epoch , before_iter , after_iter)做什么。

2. 注册新的钩子(hook)

接下来我们需要导入 MyHook。如果文件的路径是 mmrotate/core/utils/my_hook.py ,有两种方式导入:

  • 修改 mmrotate/core/utils/__init__.py 文件来导入

    新定义的模块需要在 mmrotate/core/utils/__init__.py 导入,注册表才会发现并添加该模块:

from .my_hook import MyHook
  • 在配置文件中使用 custom_imports 来手动导入

custom_imports = dict(imports=['mmrotate.core.utils.my_hook'], allow_failed_imports=False)

3. 修改配置

custom_hooks = [
    dict(type='MyHook', a=a_value, b=b_value)
]

您也可以通过配置键值 priority'NORMAL''HIGHEST' 来设置钩子的优先级:

custom_hooks = [
    dict(type='MyHook', a=a_value, b=b_value, priority='NORMAL')
]

默认地,钩子的优先级在注册时被设置为 NORMAL

使用 MMCV 中实现的钩子 (hooks)

如果钩子已经在 MMCV 里实现了,您可以直接修改配置文件来使用钩子。

4. 示例: NumClassCheckHook

我们实现了一个自定义的钩子 NumClassCheckHook ,用来检验 head 中的 num_classes 是否与 dataset 中的 CLASSSES 长度匹配。

我们在 default_runtime.py 中对其进行设置。

custom_hooks = [dict(type='NumClassCheckHook')]

修改默认运行挂钩

有一些常见的钩子并不通过 custom_hooks 注册,这些钩子包括:

  • log_config

  • checkpoint_config

  • evaluation

  • lr_config

  • optimizer_config

  • momentum_config

这些钩子中,只有记录器钩子(logger hook)是 VERY_LOW 优先级,其他钩子的优先级为 NORMAL。 前面提到的教程已经介绍了如何修改 optimizer_config , momentum_config 以及 lr_config。 这里我们介绍一下如何处理 log_config , checkpoint_config 以及 evaluation

Checkpoint config

MMCV runner 将使用 checkpoint_config 来初始化 CheckpointHook

checkpoint_config = dict(interval=1)

用户可以设置 max_keep_ckpts 来仅保存一小部分检查点(checkpoint)或者通过设置 save_optimizer 来决定是否保存优化器的状态字典 (state dict of optimizer)。更多使用参数的细节请参考 这里

Log config

log_config 包裹了许多日志钩 (logger hooks) 而且能去设置间隔 (intervals)。现在 MMCV 支持 WandbLoggerHookMlflowLoggerHookTensorboardLoggerHook。 详细的使用请参照 文档

log_config = dict(
    interval=50,
    hooks=[
        dict(type='TextLoggerHook'),
        dict(type='TensorboardLoggerHook')
    ])

Evaluation config

evaluation 的配置文件将被用来初始化 EvalHook。 除了 interval 键,其他的像 metric 这样的参数将被传递给 dataset.evaluate()

evaluation = dict(interval=1, metric='bbox')