添加链接
link之家
链接快照平台
  • 输入网页链接,自动生成快照
  • 标签化管理网页链接
【stable-diffusion企业级教程04】EMA你走,拥抱16G显存!Xformers是未来!

【stable-diffusion企业级教程04】EMA你走,拥抱16G显存!Xformers是未来!

1、回顾

上一讲我们成功的利用deepspeed框架,将sd模型以fp16的精度训练了起来,显存的消耗也降低到18G左右。

不过有小伙伴在下面评论,说手上只有16G的显卡,有没有办法再降低一些训练时的资源需求呢。

今天我们就介绍另外两个方法,进一步降低整体的显存需求。

2、ema

2.1 ema简单介绍

熟悉股票技术分析的同学应该很熟悉ma(移动平均),简单讲就是随着时间的发展,取前K天值的均值来做为一个指标。

那深度学习中的ema(exponential moving average),可以理解为是一种更新模型权重的方法,通过维持一个影子权重的方法,来对模型参数做“平均”,使得模型在最后的测评集中效果更好。

我个人的理解是,batch_gradient_decent可以看做是不同样本共同决定更新方向;而ema则是跨batch来决定更新幅度。

这里不展开,具体的可以看看这两文章:

【炼丹技巧】指数移动平均(EMA)的原理及PyTorch实现

理解滑动平均(exponential moving average) - wuliytTaotao - 博客园

2.2 模型选择

ok,回到我们的显存上来,这里提到ema是为了说明,利用ema进行更新的模型参数,会和正常进行更新的模型参数不一致,从而会保存两份参数。

这也提现到了模型上面,在这个链接中,提供了下面两个模型来下载,我们之前使用的是 sd-v1-4-full-ema.ckpt 这个模型。

利用我之前分享的repo中的 test_parameters.py ,可以将模型的参数名字打印出来。我们会发现,在vqa和text-encoder之外,模型保存了两份不同的unet模型。而这个,就是导致我们加载模型耗显存的原因。


那解决的方案也很简单,我们的base模型,可以选择下面链接中 sd-v1-4.ckpt 就可以了。

2.3 训练设置

那除了在加载模型时避免加载ema的权重,我们还需要在训练时避免生成ema相关的权重。通过观察代码,我们会发现,ema权重,是通过调用 on_train_batch_end 这个方法来实现的。

main_torch_deepspeed.py
# 4、Start train
device= torch.device(model_engine.local_rank)
for epoch in range(10*6*5): # 800/8 = 100   2*50/gpu/epoch, 300
    for i,bs in enumerate(tqdm(trainloader,desc=f"{epoch}")):
        if fp16:
            bs['image'] = bs['image'].cuda().half()
        else:
            bs['image'] = bs['image'].cuda()
        loss = model.training_step(bs, i)
        model_engine.backward(loss)
        model_engine.step()
        model.on_train_batch_end()
========》