前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >FlashAttention图解(如何加速Attention)

FlashAttention图解(如何加速Attention)

作者头像
BBuf
发布2023-10-30 19:24:05
1.2K0
发布2023-10-30 19:24:05
举报
文章被收录于专栏:GiantPandaCVGiantPandaCV

作者丨Austin

来源丨https://zhuanlan.zhihu.com/p/626079753

编辑丨GiantPandaCV


Motivation

当输入序列(sequence length)较长时,Transformer的计算过程缓慢且耗费内存,这是因为self-attention的time和memory complexity会随着sequence length的增加成二次增长。

标准Attention的中间结果S,P(见下文)通常需要通过高带宽内存(HBM)进行存取,两者所需内存空间复杂度为

。本文分析:

FlashAttention: 对HBM访问的次数为

Attention: 对HBM访问的次数为

往往

(例如GPT2中N=1024,d=64),因此FlashAttention会快很多。下图展示了两者在GPT-2上的Forward+Backward的GFLOPs、HBM、Runtime对比(A100 GPU):

GPU中存储单元主要有HBM和SRAM:HBM容量大但是访问速度慢,SRAM容量小却有着较高的访问速度。例如:A100 GPU有40-80GB的HBM,带宽为1.5-2.0TB/s;每108个流式多核处理器各有192KB的片上SRAM,带宽估计约为19TB/s。可以看出,片上的SRAM比HBM快一个数量级,但尺寸要小许多数量级。

综上,FlashAttention目的不是节约FLOPs,而是减少对HBM的访问。重点是FlashAttention在训练和预测过程中的结果和标准Attention一样,对用户是无感的,而其他加速方法做不到这点。

阅读本文需要了解的符号定义:

Method

Attention

计算流程图如下:

FlashAttention

只是部分结果,如下图所示,外循环 j 是横向(特征维 d )移动的,内循环 i 是纵向(序列维 N )移动的。换句话说,外循环在顺序计算特征,内循环在顺序计算序列。

作用是将vector生成为一个对角矩阵,从而实现相同长度的两个vector进行element-wise相乘。

Theorem 1. FlashAttention的FLOPs为

,除了input和output,额外需要的内存为


- The End -

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2023-10-27,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 GiantPandaCV 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • Method
    • Attention
      • FlashAttention
      领券
      问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档