添加链接
link之家
链接快照平台
  • 输入网页链接,自动生成快照
  • 标签化管理网页链接
精彩文章免费看

pytorch学习笔记-weight decay 和 learning rate decay

1. Weight decay

Weight decay 是一种正则化方法,大概意思就是在做梯度下降之前,当前模型的 weight 做一定程度的 decay。
上面这个就相当于是 weights 减去下面公式对权重的梯度:
整理一下就是L2正则化:

所以当 weight\_decay' =\frac{weight\_decay}{lr} 的时候,L2正则化和 weight decay 是一样的,因此也会有人说L2正则就是权重衰减。在SGD中的确是这样,但是在 Adam中就不一定了。

使用 weight decay 可以:

  • 防止过拟合
  • 保持权重在一个较小在的值,避免梯度爆炸。因为在原本的 loss 函数上加上了权重值的 L2 范数,在每次迭代时,模不仅会去优化/最小化 loss,还会使模型权重最小化。让权重值保持尽可能小,有利于控制权重值的变化幅度(如果梯度很大,说明模型本身在变化很大,去过拟合样本),从而避免梯度爆炸。
  • 在 pytorch 里可以设置 weight decay。 torch.optim.Optimizer 里, SGD、ASGD 、Adam、RMSprop 等都有weight_decay参数设置:

    optimizer = torch.optim.SGD(model.parameters(), lr=lr, weight_decay=1e-4)
    Deep learning basic-weight decay
    关于量化训练的一个小tip: weight-decay

    2. Learning rate decay

    知道梯度下降的,应该都知道学习率的影响,过大过小都会影响到学习的效果。Learning rate decay 的目的是在训练过程中逐渐降低学习率,pytorch 在torch.optim.lr_scheduler 里提供了很多花样。

    Scheduler 的定义在 optimizer之后, 而参数更新应该在一个 epoch 结束之后。

    optimizer = torch.optim.SGD(model.parameters(), lr=lr, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, mode='min', verbose=True)
    for epoch in range(10):
       for input,label in dataloader:
            optimizer.zero_grad()
            output = model(input)
            loss = loss_fn(output, target)
            loss.backward()
            optimizer.step()
        scheduler.step()
    

    权重衰减(weight decay)与学习率衰减(learning rate decay)