添加链接
link之家
链接快照平台
  • 输入网页链接,自动生成快照
  • 标签化管理网页链接

pytorch里巧用optimizer.zero_grad增大batchsize

在pytorch里,一个常规训练操作如下:

for i, minibatch in enumerate(tr_dataloader):
    features, labels = minibatch
    optimizer.zero_grad()
    loss = model(features, labels)
    loss.backward()
    optimizer.step()

那么问题来了,一定要每次从dataloader 取minibatch个数都要调用一次optimizer.zero_grad()吗?

首先给出结论:

  1. 常规情况下,每个batch需要调用一次optimizer.zero_grad函数,把参数的梯度清零
  2. 也可以多个batch 只调用一次optimizer.zero_grad函数。这样相当于增大了batch_size

即通过如下修改可以在不增加显存消耗的情况下使实际的batch_size增大N倍:

for i, minibatch in enumerate(tr_dataloader):
    features, labels = minibatch
    loss = model(features, labels)
    loss.backward()
    if 0 == i % N:
        optimizer.step()
        optimizer.zero_grad()

论证如下:

下图来自于 pytorch 手册 ,其中和本文相关的一句为:

This function accumulates gradients in the leaves - you might need to zero them before calling it.

关键词是accumulates, 也就是说求导模块autograd在给参数求导时,不是直接赋值,而是累加,即:

w.grad = w.grad + grad_current

下面代码可以进一步验证这一猜想:

import torch
from torch.autograd import Variable
x = Variable(torch.Tensor([[0]]), requires_grad=True)
y1 = x.sin()·
y1.backward()
print(x.grad) # shows 1
y2 = x.sin()·
y2.backward()
print(x.grad) # shows 2

正弦函数在0点处的导数为1, 经过两次累加,最后x.grad 为2, 符合预期。

上面的代码还反映了pytorch 动态图的特点, 即y1.backward() 已经执行,y2 = x.sin() 又基于x做了一次前向操作。所以在pytorch里前向操作可以和后向操作交替进行,梯度还会正确的累加起来。这应该就是pytorch动态图的优势所在吧。

回到题目中的问题,如果一个batch之后没有调用optimizer.zero_grad(), 那么这个batch引入的梯度会暂存在参数w.grad中,下一个batch带来的grad会累加到w.grad里,这样相当于增大了batch_size. 那么这种累加什么时候会结束呢?就是调用optimzer.zero_grad()时,这个函数告诉pytorch, 不要再管上一个batch了,这是一个新的batch了。

其实在pytorch里,不仅要手动调用optimizer.zero_grad(), 还要手动调用loss.backward().从high level 来看,神经网络一次梯度更新包含三个步骤:

  1. 前向运算
  2. 后向梯度计算; 调用loss.backward(), 在pytorch里梯度会累加
  3. 可训练参数更新; 调用optimizer.step, 更新权重

如果多次执行 loss.backward(), 而只执行一次optimizer.step(), 由于梯度累加,参数更新时用了前面多次loss.backward()的所有梯度信息。再执行一次optimizer.zero_grad(), 表示参数更新完毕,重置梯度。

所以三个函数的调用次数比应为:

 loss.backward(): optimizer.step() : optimizer.zero_grad() = N : 1 : 1

其中真正的batch_size,又loss.backward()执行的次数有关, N越大, 真正起作用的batch_size 越大。可以利用这个特点,在一些显存有限的机器上增大batch_size。比如将文章开头的代码修改如下,可以在不增加显存消耗的情况下,使batch_size增大N倍:

for i, minibatch in enumerate(tr_dataloader):