前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >大模型高效训练基础知识:梯度检查点(Gradient Checkpointing)

大模型高效训练基础知识:梯度检查点(Gradient Checkpointing)

作者头像
Steve Wang
发布2023-10-12 09:38:15
7710
发布2023-10-12 09:38:15
举报
文章被收录于专栏:从流域到海域从流域到海域

prerequiste: 大模型训练基础知识:梯度累积(Gradient Accumulationn)

梯度检查点(Gradient Checkpointing)

如今(2023年)大模型的参数量巨大,即使将batch_size设置为1并使用梯度累积的方式更新,也仍然会OOM。原因是通常在计算梯度时,我们需要将所有前向传播时的激活值保存下来,这消耗大量显存。还有另外一种延迟计算的思路,丢掉前向传播时的激活值,在计算梯度时需要哪部分的激活值就重新计算哪部分的激活值,这样做倒是解决了显存不足的问题,但加大了计算量同时也拖慢了训练。

梯度检查点(Gradient Checkpointing)在上述两种方式之间取了一个平衡,这种方法采用了一种策略选择了计算图上的一部分激活值保存下来,其余部分丢弃,这样被丢弃的那一部分激活值需要在计算梯度时重新计算。

下面这个动图展示了一种简单策略:前向传播过程中计算节点的激活值并保存,计算下一个节点完成后丢弃中间节点的激活值,反向传播时如果有保存下来的梯度就直接使用,如果没有就使用保存下来的前一个节点的梯度重新计算当前节点的梯度再使用。

在这里插入图片描述
在这里插入图片描述

Transformer框架开启梯度检查点非常简单,仅需在TrainingArguments中指定gradient checkpoint为True即可:

代码语言:javascript
复制
training_args = TrainingArguments(
    per_device_train_batch_size=1, gradient_accumulation_steps=4, gradient_checkpointing=True, **default_args
)

trainer = Trainer(model=model, args=training_args, train_dataset=ds)
result = trainer.train()
参考文献

1.Gradient Checkpointing 2.pytorch模型训练之fp16、apm、多GPU模型、梯度检查点(gradient checkpointing)显存优化等

本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2023-07-08,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 梯度检查点(Gradient Checkpointing)
  • 参考文献
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档