Shortcuts

教程 4: 自定义训练设置

自定义优化设置

自定义 Pytorch 支持的优化器

我们已经支持了全部 Pytorch 自带的优化器,唯一需要修改的就是配置文件中 optimizer 部分。 例如,如果您想使用 ADAM (注意如下操作可能会让模型表现下降),可以使用如下修改:

optimizer = dict(type='Adam', lr=0.0003, weight_decay=0.0001)

为了修改模型训练的学习率,使用者仅需修改配置文件里 optimizerlr 即可。 使用者可以参考 PyTorch 的 API doc 直接设置参数。

自定义用户自己实现的优化器

1. 定义一个新的优化器

一个自定义的优化器可以这样定义:

假如您想增加一个叫做 MyOptimizer 的优化器,它的参数分别有 a, b, 和 c。 您需要创建一个名为 mmrotate/core/optimizer 的新文件夹;然后参考如下代码段在 mmrotate/core/optimizer/my_optimizer.py 文件中实现新的优化器:

from mmdet.core.optimizer.registry import OPTIMIZERS
from torch.optim import Optimizer


@OPTIMIZERS.register_module()
class MyOptimizer(Optimizer):

    def __init__(self, a, b, c)

2. 增加优化器到注册表 (registry)

为了能够使得上述添加的模块被 mmrotate 发现,需要先将该模块添加到主命名空间(main namespace)。

  • 修改 mmrotate/core/optimizer/__init__.py 文件来导入该模块。

    新的被定义的模块应该被导入到 mmrotate/core/optimizer/__init__.py 中,这样注册表才会发现新的模块并添加它:

from .my_optimizer import MyOptimizer
  • 在配置文件中使用 custom_imports 来手动添加该模块

custom_imports = dict(imports=['mmrotate.core.optimizer.my_optimizer'], allow_failed_imports=False)

mmrotate.core.optimizer.my_optimizer 模块将会在程序开始被导入,并且 MyOptimizer 类将会自动注册。 需要注意只有包含 MyOptimizer 类的包 (package) 应当被导入。 而 mmrotate.core.optimizer.my_optimizer.MyOptimizer 不能 被直接导入。

事实上,在这种导入方式下用户可以用完全不同的文件夹结构,只要这一模块的根目录已经被添加到 PYTHONPATH 里面。

3. 在配置文件中指定优化器

之后您可以在配置文件的 optimizer 部分里面使用 MyOptimizer。 在配置文件里,优化器按照如下形式被定义在 optimizer 部分里:

optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)

要使用用户自定义的优化器,这部分应该改成:

optimizer = dict(type='MyOptimizer', a=a_value, b=b_value, c=c_value)

自定义优化器的构造函数 (constructor)

有些模型的优化器可能有一些特别参数配置,例如批归一化层 (BatchNorm layers) 的权重衰减系数 (weight decay)。 用户可以通过自定义优化器的构造函数去微调这些细粒度参数。

from mmcv.utils import build_from_cfg

from mmcv.runner.optimizer import OPTIMIZER_BUILDERS, OPTIMIZERS
from mmrotate.utils import get_root_logger
from .my_optimizer import MyOptimizer


@OPTIMIZER_BUILDERS.register_module()
class MyOptimizerConstructor(object):

    def __init__(self, optimizer_cfg, paramwise_cfg=None):

    def __call__(self, model):

        return my_optimizer

mmcv 默认的优化器构造函数实现可以参考 这里 ,这也可以作为新的优化器构造函数的模板。

其他配置

优化器未实现的技巧应该通过修改优化器构造函数(如设置基于参数的学习率)或者钩子(hooks)去实现。我们列出一些常见的设置,它们可以稳定或加速模型的训练。 如果您有更多的设置,欢迎在 PR 和 issue 里面提出。

  • 使用梯度裁剪 (gradient clip) 来稳定训练: 一些模型需要梯度裁剪来稳定训练过程。使用方式如下:

    optimizer_config = dict(
        _delete_=True, grad_clip=dict(max_norm=35, norm_type=2))
    

    如果您的配置继承了已经设置了 optimizer_config 的基础配置(base config),你可能需要设置 _delete_=True 来覆盖不必要的配置参数。请参考 配置文档 了解更多细节。

  • 使用动量调度加速模型收敛: 我们支持动量规划器(Momentum scheduler),以实现根据学习率调节模型优化过程中的动量设置,这可以使模型以更快速度收敛。 动量规划器经常与学习率规划器(LR scheduler)一起使用,例如下面的配置经常被用于 3D 检测模型训练中以加速收敛。更多细节请参考 CyclicLrUpdaterCyclicMomentumUpdater

    lr_config = dict(
        policy='cyclic',
        target_ratio=(10, 1e-4),
        cyclic_times=1,
        step_ratio_up=0.4,
    )
    momentum_config = dict(
        policy='cyclic',
        target_ratio=(0.85 / 0.95, 1),
        cyclic_times=1,
        step_ratio_up=0.4,
    )
    

自定义训练计划

默认地,我们使用 1x 计划(1x schedule)的步进学习率(step learning rate),这在 MMCV 中被称为 StepLRHook。 我们支持很多其他的学习率规划器,参考 这里 ,例如 CosineAnnealingPoly。下面是一些例子:

  • Poly :

    lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False)
    
  • ConsineAnnealing :

    lr_config = dict(
        policy='CosineAnnealing',
        warmup='linear',
        warmup_iters=1000,
        warmup_ratio=1.0 / 10,
        min_lr_ratio=1e-5)
    

自定义工作流 (workflow)

工作流是一个专门定义运行顺序和轮数(epochs)的列表。 默认情况下它设置成:

workflow = [('train', 1)]

这是指训练 1 个 epoch。 有时候用户可能想检查一些模型在验证集上的指标,如损失函数值(Loss)和准确性(Accuracy)。 在这种情况下,我们可以将工作流设置为:

[('train', 1), ('val', 1)]

这样以来, 1 个 epoch 训练,1 个 epoch 验证将交替运行。

注意:

  1. 模型参数在验证的阶段不会被自动更新。

  2. 配置文件里的键值 total_epochs 仅控制训练的 epochs 数目,而不会影响验证工作流。

  3. 工作流 [('train', 1), ('val', 1)][('train', 1)] 将不会改变 EvalHook 的行为,因为 EvalHookafter_train_epoch 调用而且验证的工作流仅仅影响通过调用 after_val_epoch 的钩子 (hooks)。因此, [('train', 1), ('val', 1)][('train', 1)] 的区别仅在于 runner 将在每次训练阶段(training epoch)结束后计算在验证集上的损失。

自定义钩 (hooks)

自定义用户自己实现的钩子(hooks)

1. 实现一个新的钩子(hook)

在某些情况下,用户可能需要实现一个新的钩子。 MMRotate 支持训练中的自定义钩子。 因此,用户可以直接在 mmrotate 或其基于 mmdet 的代码库中实现钩子,并通过仅在训练中修改配置来使用钩子。 这里我们举一个例子:在 mmrotate 中创建一个新的钩子并在训练中使用它。

from mmcv.runner import HOOKS, Hook


@HOOKS.register_module()
class MyHook(Hook):

    def __init__(self, a, b):
        pass

    def before_run(self, runner):
        pass

    def after_run(self, runner):
        pass

    def before_epoch(self, runner):
        pass

    def after_epoch(self, runner):
        pass

    def before_iter(self, runner):
        pass

    def after_iter(self, runner):
        pass

用户需要根据钩子的功能指定钩子在训练各阶段中( before_run , after_run , before_epoch , after_epoch , before_iter , after_iter)做什么。

2. 注册新的钩子(hook)

接下来我们需要导入 MyHook。如果文件的路径是 mmrotate/core/utils/my_hook.py ,有两种方式导入:

  • 修改 mmrotate/core/utils/__init__.py 文件来导入

    新定义的模块需要在 mmrotate/core/utils/__init__.py 导入,注册表才会发现并添加该模块:

from .my_hook import MyHook
  • 在配置文件中使用 custom_imports 来手动导入

custom_imports = dict(imports=['mmrotate.core.utils.my_hook'], allow_failed_imports=False)

3. 修改配置

custom_hooks = [
    dict(type='MyHook', a=a_value, b=b_value)
]

您也可以通过配置键值 priority'NORMAL''HIGHEST' 来设置钩子的优先级:

custom_hooks = [
    dict(type='MyHook', a=a_value, b=b_value, priority='NORMAL')
]

默认地,钩子的优先级在注册时被设置为 NORMAL

使用 MMCV 中实现的钩子 (hooks)

如果钩子已经在 MMCV 里实现了,您可以直接修改配置文件来使用钩子。

4. 示例: NumClassCheckHook

我们实现了一个自定义的钩子 NumClassCheckHook ,用来检验 head 中的 num_classes 是否与 dataset 中的 CLASSSES 长度匹配。

我们在 default_runtime.py 中对其进行设置。

custom_hooks = [dict(type='NumClassCheckHook')]

修改默认运行挂钩

有一些常见的钩子并不通过 custom_hooks 注册,这些钩子包括:

  • log_config

  • checkpoint_config

  • evaluation

  • lr_config

  • optimizer_config

  • momentum_config

这些钩子中,只有记录器钩子(logger hook)是 VERY_LOW 优先级,其他钩子的优先级为 NORMAL。 前面提到的教程已经介绍了如何修改 optimizer_config , momentum_config 以及 lr_config。 这里我们介绍一下如何处理 log_config , checkpoint_config 以及 evaluation

Checkpoint config

MMCV runner 将使用 checkpoint_config 来初始化 CheckpointHook

checkpoint_config = dict(interval=1)

用户可以设置 max_keep_ckpts 来仅保存一小部分检查点(checkpoint)或者通过设置 save_optimizer 来决定是否保存优化器的状态字典 (state dict of optimizer)。更多使用参数的细节请参考 这里

Log config

log_config 包裹了许多日志钩 (logger hooks) 而且能去设置间隔 (intervals)。现在 MMCV 支持 WandbLoggerHookMlflowLoggerHookTensorboardLoggerHook。 详细的使用请参照 文档

log_config = dict(
    interval=50,
    hooks=[
        dict(type='TextLoggerHook'),
        dict(type='TensorboardLoggerHook')
    ])

Evaluation config

evaluation 的配置文件将被用来初始化 EvalHook。 除了 interval 键,其他的像 metric 这样的参数将被传递给 dataset.evaluate()

evaluation = dict(interval=1, metric='bbox')
Read the Docs v: latest
Versions
latest
stable
1.x
v1.0.0rc0
v0.3.3
v0.3.2
v0.3.1
v0.3.0
v0.2.0
v0.1.1
v0.1.0
main
dev
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.